Skip to content

Commit 9bc2e1f

Browse files
SNOW-199087: Implement multi-statement support for the Python connector. (#1281)
* SNOW-199087: Implement multi-statement support for the Python connector. SNOW-676491: Executing multiple queries with no complications. SNOW-676495: Execute multiple queries with parameter bindings. * SNOW-676497: Added asynchronous execution support for multi-statements. * SNOW-676496: Adding multi-statement support to executemany and addressing some comments. * Fixing imports for olddriver test. * Addressed comments, and improved multi-statement tests * Adding telemetry, fixed tests, detect PUT/GET from server response. * Old driver test import failure fix * Add type hint for multi-statement query ID deque * Add type hint to num_statements from kwargs * Addressing more comments, bumping version for release * Address comment in DESCRIPTION.md Co-authored-by: Mark Keller <[email protected]>
1 parent adb34d6 commit 9bc2e1f

File tree

8 files changed

+605
-53
lines changed

8 files changed

+605
-53
lines changed

DESCRIPTION.md

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,22 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
88

99
# Release Notes
1010

11-
- v2.8.4(unreleased)
11+
- v2.9.0(Unreleased)
12+
1213
- Fixed a bug where the permission of the file downloaded via GET command is changed
1314
- Reworked authentication internals to allow users to plug custom key-pair authenticators
15+
- Multi-statement query execution is now supported through `cursor.execute` and `cursor.executemany`
16+
- The Snowflake parameter `MULTI_STATEMENT_COUNT` can be altered at the account, session, or statement level. An additional argument, `num_statements`, can be provided to `execute` to use this parameter at the statement level. It *must* be provided to `executemany` to submit a multi-statement query through the method. Note that bulk insert optimizations available through `executemany` are not available when submitting multi-statement queries.
17+
- By default the parameter is 1, meaning only a single query can be submitted at a time
18+
- Set to 0 to submit any number of statements in a multi-statement query
19+
- Set to >1 to submit the specified exact number of statements in a multi-statement query
20+
- Bindings are accepted in the same way for multi-statements as they are for single statement queries
21+
- Asynchronous multi-statement query execution is supported. Users should still use `get_results_from_sfqid` to retrieve results
22+
- To access the results of each query, users can call `SnowflakeCursor.nextset()`, as specified in the Python DB API 2.0 (PEP-249), to iterate through each statement's results
23+
- The first statement's results are accessible immediately after calling `execute` (or `get_results_from_sfqid` if asynchronous) through the existing `fetch*()` methods
1424

1525
- v2.8.3(November 28,2022)
26+
1627
- Bumped cryptography dependency from <39.0.0 to <41.0.0
1728
- Fixed a bug where expired OCSP response cache caused infinite recursion during cache loading
1829

src/snowflake/connector/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ class FileHeader(NamedTuple):
229229
PARAMETER_ENABLE_STAGE_S3_PRIVATELINK_FOR_US_EAST_1 = (
230230
"ENABLE_STAGE_S3_PRIVATELINK_FOR_US_EAST_1"
231231
)
232+
PARAMETER_MULTI_STATEMENT_COUNT = "MULTI_STATEMENT_COUNT"
232233

233234
HTTP_HEADER_CONTENT_TYPE = "Content-Type"
234235
HTTP_HEADER_CONTENT_ENCODING = "Content-Encoding"

src/snowflake/connector/cursor.py

Lines changed: 139 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import sys
1313
import time
1414
import uuid
15+
from collections import deque
1516
from enum import Enum
1617
from logging import getLogger
1718
from threading import Lock, Timer
@@ -190,7 +191,7 @@ class SnowflakeCursor:
190191
r".*VALUES\s*(\(.*\)).*", re.IGNORECASE | re.MULTILINE | re.DOTALL
191192
)
192193
ALTER_SESSION_RE = re.compile(
193-
r"alter\s+session\s+set\s+(.*)=\'?([^\']+)\'?\s*;",
194+
r"alter\s+session\s+set\s+(\w*?)\s*=\s*\'?([^\']+?)\'?\s*(?:;|$)",
194195
flags=re.IGNORECASE | re.MULTILINE | re.DOTALL,
195196
)
196197

@@ -232,6 +233,8 @@ def __init__(
232233
self._sequence_counter = -1
233234
self._request_id = None
234235
self._is_file_transfer = False
236+
self._multi_statement_resultIds: deque[str] = deque()
237+
self.multi_statement_savedIds: list[str] = []
235238

236239
self._timestamp_output_format = None
237240
self._timestamp_ltz_output_format = None
@@ -559,6 +562,40 @@ def interrupt_handler(*_): # pragma: no cover
559562
self._sequence_counter = -1
560563
return ret
561564

565+
def _preprocess_pyformat_query(
    self,
    command: str,
    params: Sequence[Any] | dict[Any, Any] | None = None,
) -> str:
    """Bind ``params`` into ``command`` on the client side (pyformat/format paramstyle).

    Args:
        command: SQL text containing ``%``-style placeholders.
        params: Values to bind. They are first passed through the connection's
            ``_process_params_pyformat`` and the processed result is what gets
            interpolated.

    Returns:
        ``command % processed_params`` when interpolation applies, otherwise
        ``command`` unchanged.
    """
    bound_values = self._connection._process_params_pyformat(params, self)

    # SNOW-513061 collect telemetry for empty sequence usage before we make the breaking change announcement
    if params is not None and len(params) == 0:
        if self.connection._interpolate_empty_sequences:
            empty_seq_flag = TelemetryData.TRUE
        else:
            empty_seq_flag = TelemetryData.FALSE
        self._log_telemetry_job_data(
            TelemetryField.EMPTY_SEQ_INTERPOLATION, empty_seq_flag
        )

    # Only build the (potentially large) debug message when it will be emitted.
    if logger.getEffectiveLevel() <= logging.DEBUG:
        logger.debug(
            f"binding: [{self._format_query_for_log(command)}] "
            f"with input=[{params}], "
            f"processed=[{bound_values}]",
        )

    # With empty-sequence interpolation enabled, any non-None processed value is
    # interpolated (even an empty one); otherwise only non-empty values are.
    if self.connection._interpolate_empty_sequences:
        should_interpolate = bound_values is not None
    else:
        should_interpolate = len(bound_values) > 0

    return command % bound_values if should_interpolate else command
598+
562599
def execute(
563600
self,
564601
command: str,
@@ -583,6 +620,7 @@ def execute(
583620
_raise_put_get_error: bool = True,
584621
_force_put_overwrite: bool = False,
585622
file_stream: IO[bytes] | None = None,
623+
num_statements: int | None = None,
586624
) -> SnowflakeCursor | None:
587625
"""Executes a command/query.
588626
@@ -612,6 +650,8 @@ def execute(
612650
_force_put_overwrite: If the SQL query is a PUT, then this flag can force overwriting of an already
613651
existing file on stage.
614652
file_stream: File-like object to be uploaded with PUT
653+
num_statements: Query level parameter submitted in _statement_params constraining exact number of
654+
statements being submitted (or 0 if submitting an uncounted number) when using a multi-statement query.
615655
616656
Returns:
617657
The cursor itself, or None if some error happened, or the response returned
@@ -635,6 +675,12 @@ def execute(
635675
logger.warning("execute: no query is given to execute")
636676
return
637677

678+
if _statement_params is None:
679+
_statement_params = dict()
680+
681+
if num_statements:
682+
_statement_params["MULTI_STATEMENT_COUNT"] = num_statements
683+
638684
kwargs = {
639685
"timeout": timeout,
640686
"statement_params": _statement_params,
@@ -646,33 +692,7 @@ def execute(
646692
}
647693

648694
if self._connection.is_pyformat:
649-
# pyformat/format paramstyle
650-
# client side binding
651-
processed_params = self._connection._process_params_pyformat(params, self)
652-
# SNOW-513061 collect telemetry for empty sequence usage before we make the breaking change announcement
653-
if params is not None and len(params) == 0:
654-
self._log_telemetry_job_data(
655-
TelemetryField.EMPTY_SEQ_INTERPOLATION,
656-
TelemetryData.TRUE
657-
if self.connection._interpolate_empty_sequences
658-
else TelemetryData.FALSE,
659-
)
660-
if logger.getEffectiveLevel() <= logging.DEBUG:
661-
logger.debug(
662-
f"binding: [{self._format_query_for_log(command)}] "
663-
f"with input=[{params}], "
664-
f"processed=[{processed_params}]",
665-
)
666-
if (
667-
self.connection._interpolate_empty_sequences
668-
and processed_params is not None
669-
) or (
670-
not self.connection._interpolate_empty_sequences
671-
and len(processed_params) > 0
672-
):
673-
query = command % processed_params
674-
else:
675-
query = command
695+
query = self._preprocess_pyformat_query(command, params)
676696
else:
677697
# qmark and numeric paramstyle
678698
query = command
@@ -711,11 +731,14 @@ def execute(
711731
if "data" in ret and "queryId" in ret["data"]
712732
else None
713733
)
734+
logger.debug(f"sfqid: {self.sfqid}")
714735
self._sqlstate = (
715736
ret["data"]["sqlState"]
716737
if "data" in ret and "sqlState" in ret["data"]
717738
else None
718739
)
740+
logger.info("query execution done")
741+
719742
self._first_chunk_time = get_time_millis()
720743

721744
# if server gives a send time, log the time it took to arrive
@@ -726,13 +749,27 @@ def execute(
726749
self._log_telemetry_job_data(
727750
TelemetryField.TIME_CONSUME_FIRST_RESULT, time_consume_first_result
728751
)
729-
logger.debug("sfqid: %s", self.sfqid)
730752

731-
logger.info("query execution done")
732753
if ret["success"]:
733754
logger.debug("SUCCESS")
734755
data = ret["data"]
735756

757+
for m in self.ALTER_SESSION_RE.finditer(query):
758+
# session parameters
759+
param = m.group(1).upper()
760+
value = m.group(2)
761+
self._connection.converter.set_parameter(param, value)
762+
763+
if "resultIds" in data:
764+
self._init_multi_statement_results(data)
765+
return self
766+
else:
767+
self.multi_statement_savedIds = []
768+
769+
self._is_file_transfer = "command" in data and data["command"] in (
770+
"UPLOAD",
771+
"DOWNLOAD",
772+
)
736773
logger.debug("PUT OR GET: %s", self.is_file_transfer)
737774
if self.is_file_transfer:
738775
# Decide whether to use the old, or new code path
@@ -757,12 +794,6 @@ def execute(
757794
sf_file_transfer_agent.execute()
758795
data = sf_file_transfer_agent.result()
759796
self._total_rowcount = len(data["rowset"]) if "rowset" in data else -1
760-
m = self.ALTER_SESSION_RE.match(query)
761-
if m:
762-
# session parameters
763-
param = m.group(1).upper()
764-
value = m.group(2)
765-
self._connection.converter.set_parameter(param, value)
766797

767798
if _exec_async:
768799
self.connection._async_sfqids[self._sfqid] = None
@@ -871,6 +902,22 @@ def _init_result_and_meta(self, data):
871902
else:
872903
self._total_rowcount += updated_rows
873904

905+
def _init_multi_statement_results(self, data: dict):
    """Record the result IDs of a multi-statement query and load the first result set.

    Args:
        data: Response payload whose ``"resultIds"`` entry is a comma-separated
            string of result IDs, one per statement.

    A ``ProgrammingError`` is reported through ``Error.errorhandler_wrapper``
    when the cursor is flagged as a file transfer (PUT/GET).
    """
    self._log_telemetry_job_data(TelemetryField.MULTI_STATEMENT, TelemetryData.TRUE)

    result_ids = data["resultIds"].split(",")
    self.multi_statement_savedIds = result_ids
    # Independent queue of IDs still to be fetched; nextset() pops from it
    # while the saved list above keeps the full set.
    self._multi_statement_resultIds = deque(result_ids)

    if self._is_file_transfer:
        put_get_error = {
            "msg": "PUT/GET commands are not supported for multi-statement queries and cannot be executed.",
            "errno": ER_INVALID_VALUE,
        }
        Error.errorhandler_wrapper(
            self.connection,
            self,
            ProgrammingError,
            put_get_error,
        )

    # Position the cursor on the first statement's results.
    self.nextset()
920+
874921
def check_can_use_arrow_resultset(self):
875922
global CAN_USE_ARROW_RESULT_FORMAT
876923

@@ -1002,10 +1049,17 @@ def executemany(
10021049
command = command.strip(" \t\n\r") if command else None
10031050

10041051
if not seqparams:
1052+
logger.warning(
1053+
"No parameters provided to executemany, returning without doing anything."
1054+
)
10051055
return self
10061056

1007-
if self.INSERT_SQL_RE.match(command):
1057+
if self.INSERT_SQL_RE.match(command) and (
1058+
"num_statements" not in kwargs or kwargs.get("num_statements") == 1
1059+
):
10081060
if self._connection.is_pyformat:
1061+
# TODO - utilize multi-statement instead of rewriting the query and
1062+
# accumulate results to mock the result from a single insert statement as formatted below
10091063
logger.debug("rewriting INSERT query")
10101064
command_wo_comments = re.sub(self.COMMENT_SQL_RE, "", command)
10111065
m = self.INSERT_SQL_VALUES_RE.match(command_wo_comments)
@@ -1074,8 +1128,31 @@ def executemany(
10741128
return self
10751129

10761130
self.reset()
1077-
for param in seqparams:
1078-
self.execute(command, params=param, _do_reset=False, **kwargs)
1131+
if "num_statements" not in kwargs:
1132+
# fall back to old driver behavior when the user does not provide the parameter to enable
1133+
# multi-statement optimizations for executemany
1134+
for param in seqparams:
1135+
self.execute(command, params=param, _do_reset=False, **kwargs)
1136+
else:
1137+
if re.search(";/s*$", command) is None:
1138+
command = command + "; "
1139+
if self._connection.is_pyformat:
1140+
processed_queries = [
1141+
self._preprocess_pyformat_query(command, params)
1142+
for params in seqparams
1143+
]
1144+
query = "".join(processed_queries)
1145+
params = None
1146+
else:
1147+
query = command * len(seqparams)
1148+
params = [param for parameters in seqparams for param in parameters]
1149+
1150+
kwargs["num_statements"]: int = kwargs.get("num_statements") * len(
1151+
seqparams
1152+
)
1153+
1154+
self.execute(query, params, _do_reset=False, **kwargs)
1155+
10791156
return self
10801157

10811158
def _result_iterator(
@@ -1147,8 +1224,19 @@ def fetchall(self) -> list[tuple] | list[dict]:
11471224
return ret
11481225

11491226
def nextset(self):
    """Advance to the next result set of a multi-statement query.

    The cursor is reset first. If result IDs from a previously executed
    multi-statement query remain, the results for the oldest pending ID are
    loaded so that subsequent ``fetch*()`` calls return rows from that
    statement.

    Returns:
        The cursor itself when another result set was loaded, or ``None`` when
        no more query results are available.
    """
    self.reset()

    if not self._multi_statement_resultIds:
        return None

    # Peek rather than pop: the ID is only removed from the queue once its
    # results have been retrieved.
    self.query_result(self._multi_statement_resultIds[0])
    retrieved_id = self._multi_statement_resultIds.popleft()
    logger.info(f"Retrieved results for query ID: {retrieved_id}")
    return self
11531241

11541242
def setinputsizes(self, _):
@@ -1276,6 +1364,16 @@ def wait_until_ready():
12761364
# Unset this function, so that we don't block anymore
12771365
self._prefetch_hook = None
12781366

1367+
if (
1368+
self._inner_cursor._total_rowcount == 1
1369+
and self._inner_cursor.fetchall()
1370+
== [("Multiple statements executed successfully.",)]
1371+
):
1372+
url = f"/queries/{sfqid}/result"
1373+
ret = self._connection.rest.request(url=url, method="get")
1374+
if "data" in ret and "resultIds" in ret["data"]:
1375+
self._init_multi_statement_results(ret["data"])
1376+
12791377
self.connection.get_query_status_throw_if_error(
12801378
sfqid
12811379
) # Trigger an exception if query failed

src/snowflake/connector/telemetry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ class TelemetryField(Enum):
4141
PANDAS_WRITE = "client_write_pandas"
4242
# imported packages along with client
4343
IMPORTED_PACKAGES = "client_imported_packages"
44+
# multi-statement usage
45+
MULTI_STATEMENT = "client_multi_statement_query"
4446
# Keys for telemetry data sent through either in-band or out-of-band telemetry
4547
KEY_TYPE = "type"
4648
KEY_SOURCE = "source"

src/snowflake/connector/test_util.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@
77

88
import logging
99
import os
10+
import time
11+
12+
import pytest
13+
14+
import snowflake.connector.connection
15+
from snowflake.connector.constants import QueryStatus
1016

1117
from .compat import IS_LINUX
1218

@@ -29,3 +35,37 @@
2935
)
3036
)
3137
rt_plain_logger.addHandler(ch)
38+
39+
40+
def _wait_while_query_running(
    con: snowflake.connector.connection.SnowflakeConnection,
    sfqid: str,
    sleep_time: int,
    dont_cache: bool = False,
) -> None:
    """
    Block until the connection no longer reports the query as still running.

    The status of ``sfqid`` is polled in a loop, sleeping ``sleep_time`` seconds
    between polls. When ``dont_cache`` is true the status is fetched with
    ``con._get_query_status`` instead of ``con.get_query_status``.
    """
    if dont_cache:
        fetch_status = con._get_query_status
    else:
        fetch_status = con.get_query_status

    while True:
        if not con.is_still_running(fetch_status(sfqid)):
            return
        time.sleep(sleep_time)
53+
54+
55+
def _wait_until_query_success(
    con: snowflake.connector.connection.SnowflakeConnection,
    sfqid: str,
    num_checks: int,
    sleep_per_check: int,
) -> None:
    """
    Poll a query's status until it is ``QueryStatus.SUCCESS``, failing the test otherwise.

    Args:
        con: Connection used to look up the query status.
        sfqid: ID of the query to wait for.
        num_checks: Maximum number of status checks to perform.
        sleep_per_check: Seconds to sleep after each unsuccessful check.

    The test is failed via ``pytest.fail`` when none of the ``num_checks``
    checks observes ``QueryStatus.SUCCESS``, including when ``num_checks`` is
    zero or negative and no check is performed at all.
    """
    # Pre-bind so the failure message can still be built when the loop body
    # never runs (num_checks <= 0); for/else executes the else clause then too.
    status = None
    for _ in range(num_checks):
        status = con.get_query_status(sfqid)
        if status == QueryStatus.SUCCESS:
            break
        time.sleep(sleep_per_check)
    else:
        pytest.fail(
            "We should have broken out of wait loop for query success. "
            f"Query ID: {sfqid}. "
            f"Final query status: {status}"
        )

0 commit comments

Comments
 (0)