Skip to content

Commit 1668800

Browse files
[Async] Apply #2240 to async code
1 parent d3654a6 commit 1668800

File tree

2 files changed

+58
-1
lines changed

2 files changed

+58
-1
lines changed

src/snowflake/connector/aio/_cursor.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@
5656
from snowflake.connector.telemetry import TelemetryData, TelemetryField
5757
from snowflake.connector.time_util import get_time_millis
5858

59+
from .._utils import REQUEST_ID_STATEMENT_PARAM_NAME, is_uuid4
60+
5961
if TYPE_CHECKING:
6062
from pandas import DataFrame
6163
from pyarrow import Table
@@ -202,7 +204,27 @@ async def _execute_helper(
202204
)
203205

204206
self._sequence_counter = await self._connection._next_sequence_counter()
205-
self._request_id = uuid.uuid4()
207+
208+
# If requestId is contained in statement parameters, use it to set request id. Verify here it is a valid uuid4
209+
# identifier.
210+
if (
211+
statement_params is not None
212+
and REQUEST_ID_STATEMENT_PARAM_NAME in statement_params
213+
):
214+
request_id = statement_params[REQUEST_ID_STATEMENT_PARAM_NAME]
215+
216+
if not is_uuid4(request_id):
217+
# uuid.UUID will throw an error if invalid, but we explicitly check and throw here.
218+
raise ValueError(f"requestId {request_id} is not a valid UUID4.")
219+
self._request_id = uuid.UUID(str(request_id), version=4)
220+
221+
# Create a (deep copy) and remove the statement param, there is no need to encode it as extra parameter
222+
# one more time.
223+
statement_params = statement_params.copy()
224+
statement_params.pop(REQUEST_ID_STATEMENT_PARAM_NAME)
225+
else:
226+
# Generate UUID for query.
227+
self._request_id = uuid.uuid4()
206228

207229
logger.debug(f"Request id: {self._request_id}")
208230

test/integ/aio/test_cursor_async.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import os
1313
import pickle
1414
import time
15+
import uuid
1516
from datetime import date, datetime, timezone
1617
from typing import NamedTuple
1718
from unittest import mock
@@ -1921,3 +1922,37 @@ async def test_fetch_download_timeout_setting(conn_cnx):
19211922
sql = "SELECT seq4(), uniform(1, 10, RANDOM(12)) FROM TABLE(GENERATOR(ROWCOUNT => 100000)) v"
19221923
async with conn_cnx() as con, con.cursor() as cur:
19231924
assert len(await (await cur.execute(sql)).fetchall()) == 100000
1925+
1926+
1927+
@pytest.mark.parametrize(
1928+
"request_id",
1929+
[
1930+
"THIS IS NOT VALID",
1931+
uuid.uuid1(),
1932+
uuid.uuid3(uuid.NAMESPACE_URL, "www.snowflake.com"),
1933+
uuid.uuid5(uuid.NAMESPACE_URL, "www.snowflake.com"),
1934+
],
1935+
)
1936+
async def test_custom_request_id_negative(request_id, conn_cnx):
1937+
1938+
# Ensure that invalid request_ids (non uuid4) do not compromise interface.
1939+
with pytest.raises(ValueError, match="requestId"):
1940+
async with conn_cnx() as con:
1941+
async with con.cursor() as cur:
1942+
await cur.execute(
1943+
"select seq4() as foo from table(generator(rowcount=>5))",
1944+
_statement_params={"requestId": request_id},
1945+
)
1946+
1947+
1948+
async def test_custom_request_id(conn_cnx):
1949+
request_id = uuid.uuid4()
1950+
1951+
async with conn_cnx() as con:
1952+
async with con.cursor() as cur:
1953+
await cur.execute(
1954+
"select seq4() as foo from table(generator(rowcount=>5))",
1955+
_statement_params={"requestId": request_id},
1956+
)
1957+
1958+
assert cur._sfqid is not None, "Query must execute successfully."

0 commit comments

Comments
 (0)