Skip to content

Commit 35c03cc

Browse files
SNOW-2333702 Fix types for DictCursor (#2532)
Co-authored-by: Patryk Czajka <[email protected]>
1 parent f8ec13a commit 35c03cc

File tree

5 files changed

+117
-51
lines changed

5 files changed

+117
-51
lines changed

DESCRIPTION.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
2020
- Added support for intermediate certificates as roots when they are stored in the trust store
2121
- Bumped up vendored `urllib3` to `2.5.0` and `requests` to `v2.32.5`
2222
- Dropped support for OpenSSL versions older than 1.1.1
23+
- Fixed the return type of `SnowflakeConnection.cursor(cursor_class)` to match the type of `cursor_class`
24+
- Constrained the types of `fetchone`, `fetchmany`, `fetchall`
25+
- As part of this fix, `DictCursor` is no longer a subclass of `SnowflakeCursor`; use `SnowflakeCursorBase` as a superclass of both.
2326

2427
- v3.17.3(September 02,2025)
2528
- Enhanced configuration file permission warning messages.

src/snowflake/connector/connection.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import re
99
import sys
1010
import traceback
11+
import typing
1112
import uuid
1213
import warnings
1314
import weakref
@@ -20,7 +21,16 @@
2021
from logging import getLogger
2122
from threading import Lock
2223
from types import TracebackType
23-
from typing import Any, Callable, Generator, Iterable, Iterator, NamedTuple, Sequence
24+
from typing import (
25+
Any,
26+
Callable,
27+
Generator,
28+
Iterable,
29+
Iterator,
30+
NamedTuple,
31+
Sequence,
32+
TypeVar,
33+
)
2434
from uuid import UUID
2535

2636
from cryptography.hazmat.backends import default_backend
@@ -76,7 +86,7 @@
7686
QueryStatus,
7787
)
7888
from .converter import SnowflakeConverter
79-
from .cursor import LOG_MAX_QUERY_LENGTH, SnowflakeCursor
89+
from .cursor import LOG_MAX_QUERY_LENGTH, SnowflakeCursor, SnowflakeCursorBase
8090
from .description import (
8191
CLIENT_NAME,
8292
CLIENT_VERSION,
@@ -125,6 +135,11 @@
125135
from .util_text import construct_hostname, parse_account, split_statements
126136
from .wif_util import AttestationProvider
127137

138+
if sys.version_info >= (3, 13) or typing.TYPE_CHECKING:
139+
CursorCls = TypeVar("CursorCls", bound=SnowflakeCursorBase, default=SnowflakeCursor)
140+
else:
141+
CursorCls = TypeVar("CursorCls", bound=SnowflakeCursorBase)
142+
128143
DEFAULT_CLIENT_PREFETCH_THREADS = 4
129144
MAX_CLIENT_PREFETCH_THREADS = 10
130145
MAX_CLIENT_FETCH_THREADS = 1024
@@ -1060,9 +1075,7 @@ def rollback(self) -> None:
10601075
"""Rolls back the current transaction."""
10611076
self.cursor().execute("ROLLBACK")
10621077

1063-
def cursor(
1064-
self, cursor_class: type[SnowflakeCursor] = SnowflakeCursor
1065-
) -> SnowflakeCursor:
1078+
def cursor(self, cursor_class: type[CursorCls] = SnowflakeCursor) -> CursorCls:
10661079
"""Creates a cursor object. Each statement will be executed in a new cursor object."""
10671080
logger.debug("cursor")
10681081
if not self.rest:

src/snowflake/connector/cursor.py

Lines changed: 88 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env python
22
from __future__ import annotations
33

4+
import abc
45
import collections
56
import logging
67
import os
@@ -19,12 +20,16 @@
1920
TYPE_CHECKING,
2021
Any,
2122
Callable,
23+
Dict,
24+
Generic,
2225
Iterator,
2326
Literal,
2427
NamedTuple,
2528
NoReturn,
2629
Sequence,
30+
Tuple,
2731
TypeVar,
32+
Union,
2833
overload,
2934
)
3035

@@ -86,6 +91,7 @@
8691
from .result_batch import ResultBatch
8792

8893
T = TypeVar("T", bound=collections.abc.Sequence)
94+
FetchRow = TypeVar("FetchRow", bound=Union[Tuple[Any, ...], Dict[str, Any]])
8995

9096
logger = getLogger(__name__)
9197

@@ -332,29 +338,7 @@ class ResultState(Enum):
332338
RESET = 3
333339

334340

335-
class SnowflakeCursor:
336-
"""Implementation of Cursor object that is returned from Connection.cursor() method.
337-
338-
Attributes:
339-
description: A list of namedtuples about metadata for all columns.
340-
rowcount: The number of records updated or selected. If not clear, -1 is returned.
341-
rownumber: The current 0-based index of the cursor in the result set or None if the index cannot be
342-
determined.
343-
sfqid: Snowflake query id in UUID form. Include this in the problem report to the customer support.
344-
sqlstate: Snowflake SQL State code.
345-
timestamp_output_format: Snowflake timestamp_output_format for timestamps.
346-
timestamp_ltz_output_format: Snowflake output format for LTZ timestamps.
347-
timestamp_tz_output_format: Snowflake output format for TZ timestamps.
348-
timestamp_ntz_output_format: Snowflake output format for NTZ timestamps.
349-
date_output_format: Snowflake output format for dates.
350-
time_output_format: Snowflake output format for times.
351-
timezone: Snowflake timezone.
352-
binary_output_format: Snowflake output format for binary fields.
353-
arraysize: The default number of rows fetched by fetchmany.
354-
connection: The connection object by which the cursor was created.
355-
errorhandle: The class that handles error handling.
356-
is_file_transfer: Whether, or not the current command is a put, or get.
357-
"""
341+
class SnowflakeCursorBase(abc.ABC, Generic[FetchRow]):
358342

359343
# TODO:
360344
# Most of these attributes have no reason to be properties, we could just store them in public variables.
@@ -382,13 +366,11 @@ def get_file_transfer_type(sql: str) -> FileTransferType | None:
382366
def __init__(
383367
self,
384368
connection: SnowflakeConnection,
385-
use_dict_result: bool = False,
386369
) -> None:
387370
"""Inits a SnowflakeCursor with a connection.
388371
389372
Args:
390373
connection: The connection that created this cursor.
391-
use_dict_result: Decides whether to use dict result or not.
392374
"""
393375
self._connection: SnowflakeConnection = connection
394376

@@ -423,7 +405,6 @@ def __init__(
423405
self._result: Iterator[tuple] | Iterator[dict] | None = None
424406
self._result_set: ResultSet | None = None
425407
self._result_state: ResultState = ResultState.DEFAULT
426-
self._use_dict_result = use_dict_result
427408
self.query: str | None = None
428409
# TODO: self._query_result_format could be defined as an enum
429410
self._query_result_format: str | None = None
@@ -435,7 +416,7 @@ def __init__(
435416
self._first_chunk_time = None
436417

437418
self._log_max_query_length = connection.log_max_query_length
438-
self._inner_cursor: SnowflakeCursor | None = None
419+
self._inner_cursor: SnowflakeCursorBase | None = None
439420
self._prefetch_hook = None
440421
self._rownumber: int | None = None
441422

@@ -448,6 +429,12 @@ def __del__(self) -> None: # pragma: no cover
448429
if logger.getEffectiveLevel() <= logging.INFO:
449430
logger.info(e)
450431

432+
@property
433+
@abc.abstractmethod
434+
def _use_dict_result(self) -> bool:
435+
"""Decides whether results from helper functions are returned as a dict."""
436+
pass
437+
451438
@property
452439
def description(self) -> list[ResultMetadata]:
453440
if self._description is None:
@@ -1514,8 +1501,17 @@ def executemany(
15141501

15151502
return self
15161503

1517-
def fetchone(self) -> dict | tuple | None:
1518-
"""Fetches one row."""
1504+
@abc.abstractmethod
1505+
def fetchone(self) -> FetchRow:
1506+
pass
1507+
1508+
def _fetchone(self) -> dict[str, Any] | tuple[Any, ...] | None:
1509+
"""
1510+
Fetches one row.
1511+
1512+
Returns a dict if self._use_dict_result is True, otherwise
1513+
returns tuple.
1514+
"""
15191515
if self._prefetch_hook is not None:
15201516
self._prefetch_hook()
15211517
if self._result is None and self._result_set is not None:
@@ -1539,7 +1535,7 @@ def fetchone(self) -> dict | tuple | None:
15391535
else:
15401536
return None
15411537

1542-
def fetchmany(self, size: int | None = None) -> list[tuple] | list[dict]:
1538+
def fetchmany(self, size: int | None = None) -> list[FetchRow]:
15431539
"""Fetches the number of specified rows."""
15441540
if size is None:
15451541
size = self.arraysize
@@ -1565,7 +1561,7 @@ def fetchmany(self, size: int | None = None) -> list[tuple] | list[dict]:
15651561

15661562
return ret
15671563

1568-
def fetchall(self) -> list[tuple] | list[dict]:
1564+
def fetchall(self) -> list[FetchRow]:
15691565
"""Fetches all of the results."""
15701566
ret = []
15711567
while True:
@@ -1728,20 +1724,31 @@ def wait_until_ready() -> None:
17281724
# Unset this function, so that we don't block anymore
17291725
self._prefetch_hook = None
17301726

1731-
if (
1732-
self._inner_cursor._total_rowcount == 1
1733-
and self._inner_cursor.fetchall()
1734-
== [("Multiple statements executed successfully.",)]
1727+
if self._inner_cursor._total_rowcount == 1 and _is_successful_multi_stmt(
1728+
self._inner_cursor.fetchall()
17351729
):
17361730
url = f"/queries/{sfqid}/result"
17371731
ret = self._connection.rest.request(url=url, method="get")
17381732
if "data" in ret and "resultIds" in ret["data"]:
17391733
self._init_multi_statement_results(ret["data"])
17401734

1735+
def _is_successful_multi_stmt(rows: list[Any]) -> bool:
1736+
if len(rows) != 1:
1737+
return False
1738+
row = rows[0]
1739+
if isinstance(row, tuple):
1740+
return row == ("Multiple statements executed successfully.",)
1741+
elif isinstance(row, dict):
1742+
return row == {
1743+
"multiple statement execution": "Multiple statements executed successfully."
1744+
}
1745+
else:
1746+
return False
1747+
17411748
self.connection.get_query_status_throw_if_error(
17421749
sfqid
17431750
) # Trigger an exception if query failed
1744-
self._inner_cursor = SnowflakeCursor(self.connection)
1751+
self._inner_cursor = self.__class__(self.connection)
17451752
self._sfqid = sfqid
17461753
self._prefetch_hook = wait_until_ready
17471754

@@ -1925,14 +1932,53 @@ def _create_file_transfer_agent(
19251932
)
19261933

19271934

1928-
class DictCursor(SnowflakeCursor):
1935+
class SnowflakeCursor(SnowflakeCursorBase[tuple[Any, ...]]):
1936+
"""Implementation of Cursor object that is returned from Connection.cursor() method.
1937+
1938+
Attributes:
1939+
description: A list of namedtuples about metadata for all columns.
1940+
rowcount: The number of records updated or selected. If not clear, -1 is returned.
1941+
rownumber: The current 0-based index of the cursor in the result set or None if the index cannot be
1942+
determined.
1943+
sfqid: Snowflake query id in UUID form. Include this in the problem report to the customer support.
1944+
sqlstate: Snowflake SQL State code.
1945+
timestamp_output_format: Snowflake timestamp_output_format for timestamps.
1946+
timestamp_ltz_output_format: Snowflake output format for LTZ timestamps.
1947+
timestamp_tz_output_format: Snowflake output format for TZ timestamps.
1948+
timestamp_ntz_output_format: Snowflake output format for NTZ timestamps.
1949+
date_output_format: Snowflake output format for dates.
1950+
time_output_format: Snowflake output format for times.
1951+
timezone: Snowflake timezone.
1952+
binary_output_format: Snowflake output format for binary fields.
1953+
arraysize: The default number of rows fetched by fetchmany.
1954+
connection: The connection object by which the cursor was created.
1955+
errorhandle: The class that handles error handling.
1956+
is_file_transfer: Whether, or not the current command is a put, or get.
1957+
"""
1958+
1959+
@property
1960+
def _use_dict_result(self) -> bool:
1961+
return False
1962+
1963+
def fetchone(self) -> tuple[Any, ...] | None:
1964+
row = self._fetchone()
1965+
if not (row is None or isinstance(row, tuple)):
1966+
raise TypeError(f"fetchone got unexpected result: {row}")
1967+
return row
1968+
1969+
1970+
class DictCursor(SnowflakeCursorBase[dict[str, Any]]):
19291971
"""Cursor returning results in a dictionary."""
19301972

1931-
def __init__(self, connection) -> None:
1932-
super().__init__(
1933-
connection,
1934-
use_dict_result=True,
1935-
)
1973+
@property
1974+
def _use_dict_result(self) -> bool:
1975+
return True
1976+
1977+
def fetchone(self) -> dict[str, Any] | None:
1978+
row = self._fetchone()
1979+
if not (row is None or isinstance(row, dict)):
1980+
raise TypeError(f"fetchone got unexpected result: {row}")
1981+
return row
19361982

19371983

19381984
def __getattr__(name):

test/integ/test_async.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818
QueryStatus = None
1919

2020

21-
@pytest.mark.parametrize("cursor_class", [SnowflakeCursor, DictCursor])
22-
def test_simple_async(conn_cnx, cursor_class):
21+
@pytest.mark.parametrize(
22+
"cursor_class, row_type", [(SnowflakeCursor, tuple), (DictCursor, dict)]
23+
)
24+
def test_simple_async(conn_cnx, cursor_class, row_type):
2325
"""Simple test to that shows the most simple usage of fire and forget.
2426
2527
This test also makes sure that wait_until_ready function's sleeping is tested and
@@ -29,7 +31,9 @@ def test_simple_async(conn_cnx, cursor_class):
2931
with con.cursor(cursor_class) as cur:
3032
cur.execute_async("select count(*) from table(generator(timeLimit => 5))")
3133
cur.get_results_from_sfqid(cur.sfqid)
32-
assert len(cur.fetchall()) == 1
34+
rows = cur.fetchall()
35+
assert len(rows) == 1
36+
assert isinstance(rows[0], row_type)
3337
assert cur.rowcount
3438
assert cur.description
3539

test/unit/test_cursor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def mock_is_closed(*args, **kwargs):
7878
cursor.execute("", _dataframe_ast="ABCD")
7979

8080

81-
@patch("snowflake.connector.cursor.SnowflakeCursor._SnowflakeCursor__cancel_query")
81+
@patch("snowflake.connector.cursor.SnowflakeCursor._SnowflakeCursorBase__cancel_query")
8282
def test_cursor_execute_timeout(mockCancelQuery):
8383
def mock_cmd_query(*args, **kwargs):
8484
time.sleep(10)

0 commit comments

Comments
 (0)