6
6
7
7
import asyncio
8
8
import collections
9
+ import logging
9
10
import re
10
11
import signal
11
12
import sys
13
+ import typing
12
14
import uuid
13
15
from logging import getLogger
14
16
from types import TracebackType
30
32
create_batches_from_response ,
31
33
)
32
34
from snowflake .connector .aio ._result_set import ResultSet , ResultSetIterator
33
- from snowflake .connector .constants import PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT
34
- from snowflake .connector .cursor import DESC_TABLE_RE
35
+ from snowflake .connector .constants import (
36
+ PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT ,
37
+ QueryStatus ,
38
+ )
39
+ from snowflake .connector .cursor import (
40
+ ASYNC_NO_DATA_MAX_RETRY ,
41
+ ASYNC_RETRY_PATTERN ,
42
+ DESC_TABLE_RE ,
43
+ )
35
44
from snowflake .connector .cursor import DictCursor as DictCursorSync
36
45
from snowflake .connector .cursor import ResultMetadata , ResultMetadataV2 , ResultState
37
46
from snowflake .connector .cursor import SnowflakeCursor as SnowflakeCursorSync
43
52
ER_INVALID_VALUE ,
44
53
ER_NOT_POSITIVE_SIZE ,
45
54
)
46
- from snowflake .connector .errors import BindUploadError
55
+ from snowflake .connector .errors import BindUploadError , DatabaseError
47
56
from snowflake .connector .file_transfer_agent import SnowflakeProgressPercentage
48
57
from snowflake .connector .telemetry import TelemetryField
49
58
from snowflake .connector .time_util import get_time_millis
@@ -65,9 +74,11 @@ def __init__(
65
74
):
66
75
super ().__init__ (connection , use_dict_result )
67
76
# the following fixes type hint
68
- self ._connection : SnowflakeConnection = connection
77
+ self ._connection = typing .cast ("SnowflakeConnection" , self ._connection )
78
+ self ._inner_cursor = typing .cast (SnowflakeCursor , self ._inner_cursor )
69
79
self ._lock_canceling = asyncio .Lock ()
70
80
self ._timebomb : asyncio .Task | None = None
81
+ self ._prefetch_hook : typing .Callable [[], typing .Awaitable ] | None = None
71
82
72
83
def __aiter__ (self ):
73
84
return self
@@ -87,6 +98,18 @@ async def __anext__(self):
87
98
async def __aenter__ (self ):
88
99
return self
89
100
101
+ def __enter__ (self ):
102
+ # async cursor does not support sync context manager
103
+ raise TypeError (
104
+ "'SnowflakeCursor' object does not support the context manager protocol"
105
+ )
106
+
107
+ def __exit__ (self , exc_type , exc_val , exc_tb ):
108
+ # async cursor does not support sync context manager
109
+ raise TypeError (
110
+ "'SnowflakeCursor' object does not support the context manager protocol"
111
+ )
112
+
90
113
def __del__ (self ):
91
114
# do nothing in async, __del__ is unreliable
92
115
pass
@@ -337,6 +360,7 @@ async def _init_result_and_meta(self, data: dict[Any, Any]) -> None:
337
360
self ._total_rowcount += updated_rows
338
361
339
362
async def _init_multi_statement_results (self , data : dict ) -> None :
363
+ # TODO: async telemetry SNOW-1572217
340
364
# self._log_telemetry_job_data(TelemetryField.MULTI_STATEMENT, TelemetryData.TRUE)
341
365
self .multi_statement_savedIds = data ["resultIds" ].split ("," )
342
366
self ._multi_statement_resultIds = collections .deque (
@@ -357,7 +381,45 @@ async def _init_multi_statement_results(self, data: dict) -> None:
357
381
async def _log_telemetry_job_data (
358
382
self , telemetry_field : TelemetryField , value : Any
359
383
) -> None :
360
- raise NotImplementedError ("Telemetry is not supported in async." )
384
+ # TODO: async telemetry SNOW-1572217
385
+ pass
386
+
387
+ async def _preprocess_pyformat_query (
388
+ self ,
389
+ command : str ,
390
+ params : Sequence [Any ] | dict [Any , Any ] | None = None ,
391
+ ) -> str :
392
+ # pyformat/format paramstyle
393
+ # client side binding
394
+ processed_params = self ._connection ._process_params_pyformat (params , self )
395
+ # SNOW-513061 collect telemetry for empty sequence usage before we make the breaking change announcement
396
+ # TODO: async telemetry support
397
+ # if params is not None and len(params) == 0:
398
+ # await self._log_telemetry_job_data(
399
+ # TelemetryField.EMPTY_SEQ_INTERPOLATION,
400
+ # (
401
+ # TelemetryData.TRUE
402
+ # if self.connection._interpolate_empty_sequences
403
+ # else TelemetryData.FALSE
404
+ # ),
405
+ # )
406
+ if logger .getEffectiveLevel () <= logging .DEBUG :
407
+ logger .debug (
408
+ f"binding: [{ self ._format_query_for_log (command )} ] "
409
+ f"with input=[{ params } ], "
410
+ f"processed=[{ processed_params } ]" ,
411
+ )
412
+ if (
413
+ self .connection ._interpolate_empty_sequences
414
+ and processed_params is not None
415
+ ) or (
416
+ not self .connection ._interpolate_empty_sequences
417
+ and len (processed_params ) > 0
418
+ ):
419
+ query = command % processed_params
420
+ else :
421
+ query = command
422
+ return query
361
423
362
424
async def abort_query (self , qid : str ) -> bool :
363
425
url = f"/queries/{ qid } /abort-request"
@@ -387,6 +449,10 @@ async def callproc(self, procname: str, args=tuple()):
387
449
await self .execute (command , args )
388
450
return args
389
451
452
@property
def connection(self) -> "SnowflakeConnection":
    """The connection this cursor was created from."""
    return self._connection
455
+
390
456
async def close (self ):
391
457
"""Closes the cursor object.
392
458
@@ -471,7 +537,7 @@ async def execute(
471
537
}
472
538
473
539
if self ._connection .is_pyformat :
474
- query = self ._preprocess_pyformat_query (command , params )
540
+ query = await self ._preprocess_pyformat_query (command , params )
475
541
else :
476
542
# qmark and numeric paramstyle
477
543
query = command
@@ -538,7 +604,7 @@ async def execute(
538
604
self ._connection .converter .set_parameter (param , value )
539
605
540
606
if "resultIds" in data :
541
- self ._init_multi_statement_results (data )
607
+ await self ._init_multi_statement_results (data )
542
608
return self
543
609
else :
544
610
self .multi_statement_savedIds = []
@@ -707,7 +773,7 @@ async def executemany(
707
773
command = command + "; "
708
774
if self ._connection .is_pyformat :
709
775
processed_queries = [
710
- self ._preprocess_pyformat_query (command , params )
776
+ await self ._preprocess_pyformat_query (command , params )
711
777
for params in seqparams
712
778
]
713
779
query = "" .join (processed_queries )
@@ -752,7 +818,7 @@ async def describe(self, *args: Any, **kwargs: Any) -> list[ResultMetadata]:
752
818
async def fetchone (self ) -> dict | tuple | None :
753
819
"""Fetches one row."""
754
820
if self ._prefetch_hook is not None :
755
- self ._prefetch_hook ()
821
+ await self ._prefetch_hook ()
756
822
if self ._result is None and self ._result_set is not None :
757
823
self ._result : ResultSetIterator = await self ._result_set ._create_iter ()
758
824
self ._result_state = ResultState .VALID
@@ -804,7 +870,7 @@ async def fetchmany(self, size: int | None = None) -> list[tuple] | list[dict]:
804
870
async def fetchall (self ) -> list [tuple ] | list [dict ]:
805
871
"""Fetches all of the results."""
806
872
if self ._prefetch_hook is not None :
807
- self ._prefetch_hook ()
873
+ await self ._prefetch_hook ()
808
874
if self ._result is None and self ._result_set is not None :
809
875
self ._result : ResultSetIterator = await self ._result_set ._create_iter (
810
876
is_fetch_all = True
@@ -822,9 +888,10 @@ async def fetchall(self) -> list[tuple] | list[dict]:
822
888
async def fetch_arrow_batches (self ) -> AsyncIterator [Table ]:
823
889
self .check_can_use_arrow_resultset ()
824
890
if self ._prefetch_hook is not None :
825
- self ._prefetch_hook ()
891
+ await self ._prefetch_hook ()
826
892
if self ._query_result_format != "arrow" :
827
893
raise NotSupportedError
894
+ # TODO: async telemetry SNOW-1572217
828
895
# self._log_telemetry_job_data(
829
896
# TelemetryField.ARROW_FETCH_BATCHES, TelemetryData.TRUE
830
897
# )
@@ -848,9 +915,10 @@ async def fetch_arrow_all(self, force_return_table: bool = False) -> Table | Non
848
915
self .check_can_use_arrow_resultset ()
849
916
850
917
if self ._prefetch_hook is not None :
851
- self ._prefetch_hook ()
918
+ await self ._prefetch_hook ()
852
919
if self ._query_result_format != "arrow" :
853
920
raise NotSupportedError
921
+ # TODO: async telemetry SNOW-1572217
854
922
# self._log_telemetry_job_data(TelemetryField.ARROW_FETCH_ALL, TelemetryData.TRUE)
855
923
return await self ._result_set ._fetch_arrow_all (
856
924
force_return_table = force_return_table
@@ -860,7 +928,7 @@ async def fetch_pandas_batches(self, **kwargs: Any) -> AsyncIterator[DataFrame]:
860
928
"""Fetches a single Arrow Table."""
861
929
self .check_can_use_pandas ()
862
930
if self ._prefetch_hook is not None :
863
- self ._prefetch_hook ()
931
+ await self ._prefetch_hook ()
864
932
if self ._query_result_format != "arrow" :
865
933
raise NotSupportedError
866
934
# TODO: async telemetry
@@ -872,7 +940,7 @@ async def fetch_pandas_batches(self, **kwargs: Any) -> AsyncIterator[DataFrame]:
872
940
async def fetch_pandas_all (self , ** kwargs : Any ) -> DataFrame :
873
941
self .check_can_use_pandas ()
874
942
if self ._prefetch_hook is not None :
875
- self ._prefetch_hook ()
943
+ await self ._prefetch_hook ()
876
944
if self ._query_result_format != "arrow" :
877
945
raise NotSupportedError
878
946
# # TODO: async telemetry
@@ -917,8 +985,70 @@ async def get_result_batches(self) -> list[ResultBatch] | None:
917
985
return self ._result_set .batches
918
986
919
987
async def get_results_from_sfqid(self, sfqid: str) -> None:
    """Gets the results from previously ran query. This methods differs from ``SnowflakeCursor.query_result``
    in that it monitors the ``sfqid`` until it is no longer running, and then retrieves the results.

    The retrieval itself is deferred: it is stored in ``self._prefetch_hook``
    and awaited lazily by the first fetch call.
    """

    async def wait_until_ready() -> None:
        """Makes sure query has finished executing and once it has retrieves results."""
        no_data_counter = 0
        retry_pattern_pos = 0
        while True:
            status, status_resp = await self.connection._get_query_status(sfqid)
            self.connection._cache_query_status(sfqid, status)
            if not self.connection.is_still_running(status):
                break
            if status == QueryStatus.NO_DATA:  # pragma: no cover
                no_data_counter += 1
                if no_data_counter > ASYNC_NO_DATA_MAX_RETRY:
                    # Fix: interpolate the query id — the '{}' placeholder was
                    # previously emitted verbatim in the error message.
                    raise DatabaseError(
                        "Cannot retrieve data on the status of this query. No information returned "
                        "from server for query '{}'".format(sfqid)
                    )
            await asyncio.sleep(
                0.5 * ASYNC_RETRY_PATTERN[retry_pattern_pos]
            )  # Same wait as JDBC
            # If we can advance in ASYNC_RETRY_PATTERN then do so
            if retry_pattern_pos < (len(ASYNC_RETRY_PATTERN) - 1):
                retry_pattern_pos += 1
        if status != QueryStatus.SUCCESS:
            logger.info(f"Status of query '{sfqid}' is {status.name}")
            # Raises the error built from the server's status response.
            self.connection._process_error_query_status(
                sfqid,
                status_resp,
                error_message=f"Status of query '{sfqid}' is {status.name}, results are unavailable",
                error_cls=DatabaseError,
            )
        # Re-read the finished query's output via RESULT_SCAN on the helper
        # cursor, then adopt its result state onto this cursor.
        await self._inner_cursor.execute(
            f"select * from table(result_scan('{sfqid}'))"
        )
        self._result = self._inner_cursor._result
        self._query_result_format = self._inner_cursor._query_result_format
        self._total_rowcount = self._inner_cursor._total_rowcount
        self._description = self._inner_cursor._description
        self._result_set = self._inner_cursor._result_set
        self._result_state = ResultState.VALID
        self._rownumber = 0
        # Unset this function, so that we don't block anymore
        self._prefetch_hook = None

        if (
            self._inner_cursor._total_rowcount == 1
            and await self._inner_cursor.fetchall()
            == [("Multiple statements executed successfully.",)]
        ):
            # Multi-statement parent query: fetch the child result ids so the
            # individual statement results can be iterated.
            url = f"/queries/{sfqid}/result"
            ret = await self._connection.rest.request(url=url, method="get")
            if "data" in ret and "resultIds" in ret["data"]:
                await self._init_multi_statement_results(ret["data"])

    await self.connection.get_query_status_throw_if_error(
        sfqid
    )  # Trigger an exception if query failed
    klass = self.__class__
    self._inner_cursor = klass(self.connection)
    self._sfqid = sfqid
    self._prefetch_hook = wait_until_ready
922
1052
923
1053
async def query_result (self , qid : str ) -> SnowflakeCursor :
924
1054
url = f"/queries/{ qid } /result"
0 commit comments