Skip to content

Commit eda89f0

Browse files
[Async] Apply #2240 to async code
1 parent 7a531cd commit eda89f0

File tree

2 files changed

+58
-1
lines changed

2 files changed

+58
-1
lines changed

src/snowflake/connector/aio/_cursor.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@
5656
from snowflake.connector.telemetry import TelemetryData, TelemetryField
5757
from snowflake.connector.time_util import get_time_millis
5858

59+
from .._utils import REQUEST_ID_STATEMENT_PARAM_NAME, is_uuid4
60+
5961
if TYPE_CHECKING:
6062
from pandas import DataFrame
6163
from pyarrow import Table
@@ -202,7 +204,27 @@ async def _execute_helper(
202204
)
203205

204206
self._sequence_counter = await self._connection._next_sequence_counter()
205-
self._request_id = uuid.uuid4()
207+
208+
# If requestId is contained in statement parameters, use it to set request id. Verify here it is a valid uuid4
209+
# identifier.
210+
if (
211+
statement_params is not None
212+
and REQUEST_ID_STATEMENT_PARAM_NAME in statement_params
213+
):
214+
request_id = statement_params[REQUEST_ID_STATEMENT_PARAM_NAME]
215+
216+
if not is_uuid4(request_id):
217+
# uuid.UUID will throw an error if invalid, but we explicitly check and throw here.
218+
raise ValueError(f"requestId {request_id} is not a valid UUID4.")
219+
self._request_id = uuid.UUID(str(request_id), version=4)
220+
221+
# Create a (deep copy) and remove the statement param, there is no need to encode it as extra parameter
222+
# one more time.
223+
statement_params = statement_params.copy()
224+
statement_params.pop(REQUEST_ID_STATEMENT_PARAM_NAME)
225+
else:
226+
# Generate UUID for query.
227+
self._request_id = uuid.uuid4()
206228

207229
logger.debug(f"Request id: {self._request_id}")
208230

test/integ/aio/test_cursor_async.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import os
1313
import pickle
1414
import time
15+
import uuid
1516
from datetime import date, datetime, timezone
1617
from typing import NamedTuple
1718
from unittest import mock
@@ -1920,3 +1921,37 @@ async def test_fetch_download_timeout_setting(conn_cnx):
19201921
sql = "SELECT seq4(), uniform(1, 10, RANDOM(12)) FROM TABLE(GENERATOR(ROWCOUNT => 100000)) v"
19211922
async with conn_cnx() as con, con.cursor() as cur:
19221923
assert len(await (await cur.execute(sql)).fetchall()) == 100000
1924+
1925+
1926+
@pytest.mark.parametrize(
1927+
"request_id",
1928+
[
1929+
"THIS IS NOT VALID",
1930+
uuid.uuid1(),
1931+
uuid.uuid3(uuid.NAMESPACE_URL, "www.snowflake.com"),
1932+
uuid.uuid5(uuid.NAMESPACE_URL, "www.snowflake.com"),
1933+
],
1934+
)
1935+
async def test_custom_request_id_negative(request_id, conn_cnx):
1936+
1937+
# Ensure that invalid request_ids (non uuid4) do not compromise interface.
1938+
with pytest.raises(ValueError, match="requestId"):
1939+
async with conn_cnx() as con:
1940+
async with con.cursor() as cur:
1941+
await cur.execute(
1942+
"select seq4() as foo from table(generator(rowcount=>5))",
1943+
_statement_params={"requestId": request_id},
1944+
)
1945+
1946+
1947+
async def test_custom_request_id(conn_cnx):
1948+
request_id = uuid.uuid4()
1949+
1950+
async with conn_cnx() as con:
1951+
async with con.cursor() as cur:
1952+
await cur.execute(
1953+
"select seq4() as foo from table(generator(rowcount=>5))",
1954+
_statement_params={"requestId": request_id},
1955+
)
1956+
1957+
assert cur._sfqid is not None, "Query must execute successfully."

0 commit comments

Comments
 (0)