diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 40d4039ee8..d463bca9e3 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -20,6 +20,9 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne - Added support for intermediate certificates as roots when they are stored in the trust store - Bumped up vendored `urllib3` to `2.5.0` and `requests` to `v2.32.5` - Dropped support for OpenSSL versions older than 1.1.1 + - Fixed the return type of `SnowflakeConnection.cursor(cursor_class)` to match the type of `cursor_class` + - Constrained the return types of `fetchone`, `fetchmany`, and `fetchall` to the row type of the cursor class (tuples for `SnowflakeCursor`, dicts for `DictCursor`) + - As part of this fix, `DictCursor` is no longer a subclass of `SnowflakeCursor`; both now derive from `SnowflakeCursorBase`, so use `SnowflakeCursorBase` in `isinstance` checks and annotations that must accept either cursor. - v3.17.3(September 02,2025) - Enhanced configuration file permission warning messages. diff --git a/src/snowflake/connector/connection.py b/src/snowflake/connector/connection.py index cfc592e331..c4efe25f8c 100644 --- a/src/snowflake/connector/connection.py +++ b/src/snowflake/connector/connection.py @@ -8,6 +8,7 @@ import re import sys import traceback +import typing import uuid import warnings import weakref @@ -20,7 +21,16 @@ from logging import getLogger from threading import Lock from types import TracebackType -from typing import Any, Callable, Generator, Iterable, Iterator, NamedTuple, Sequence +from typing import ( + Any, + Callable, + Generator, + Iterable, + Iterator, + NamedTuple, + Sequence, + TypeVar, +) from uuid import UUID from cryptography.hazmat.backends import default_backend @@ -76,7 +86,7 @@ QueryStatus, ) from .converter import SnowflakeConverter -from .cursor import LOG_MAX_QUERY_LENGTH, SnowflakeCursor +from .cursor import LOG_MAX_QUERY_LENGTH, SnowflakeCursor, SnowflakeCursorBase from .description import ( CLIENT_NAME, CLIENT_VERSION, @@ -125,6 +135,11 @@ from .util_text import construct_hostname, parse_account, split_statements from .wif_util import AttestationProvider +if sys.version_info >= (3, 13) 
or typing.TYPE_CHECKING: + CursorCls = TypeVar("CursorCls", bound=SnowflakeCursorBase, default=SnowflakeCursor) +else: + CursorCls = TypeVar("CursorCls", bound=SnowflakeCursorBase) + DEFAULT_CLIENT_PREFETCH_THREADS = 4 MAX_CLIENT_PREFETCH_THREADS = 10 MAX_CLIENT_FETCH_THREADS = 1024 @@ -1060,9 +1075,7 @@ def rollback(self) -> None: """Rolls back the current transaction.""" self.cursor().execute("ROLLBACK") - def cursor( - self, cursor_class: type[SnowflakeCursor] = SnowflakeCursor - ) -> SnowflakeCursor: + def cursor(self, cursor_class: type[CursorCls] = SnowflakeCursor) -> CursorCls: """Creates a cursor object. Each statement will be executed in a new cursor object.""" logger.debug("cursor") if not self.rest: diff --git a/src/snowflake/connector/cursor.py b/src/snowflake/connector/cursor.py index 6ade7f3d8e..c13ab242c7 100644 --- a/src/snowflake/connector/cursor.py +++ b/src/snowflake/connector/cursor.py @@ -1,6 +1,7 @@ #!/usr/bin/env python from __future__ import annotations +import abc import collections import logging import os @@ -19,12 +20,16 @@ TYPE_CHECKING, Any, Callable, + Dict, + Generic, Iterator, Literal, NamedTuple, NoReturn, Sequence, + Tuple, TypeVar, + Union, overload, ) @@ -86,6 +91,7 @@ from .result_batch import ResultBatch T = TypeVar("T", bound=collections.abc.Sequence) +FetchRow = TypeVar("FetchRow", bound=Union[Tuple[Any, ...], Dict[str, Any]]) logger = getLogger(__name__) @@ -332,29 +338,7 @@ class ResultState(Enum): RESET = 3 -class SnowflakeCursor: - """Implementation of Cursor object that is returned from Connection.cursor() method. - - Attributes: - description: A list of namedtuples about metadata for all columns. - rowcount: The number of records updated or selected. If not clear, -1 is returned. - rownumber: The current 0-based index of the cursor in the result set or None if the index cannot be - determined. - sfqid: Snowflake query id in UUID form. Include this in the problem report to the customer support. 
- sqlstate: Snowflake SQL State code. - timestamp_output_format: Snowflake timestamp_output_format for timestamps. - timestamp_ltz_output_format: Snowflake output format for LTZ timestamps. - timestamp_tz_output_format: Snowflake output format for TZ timestamps. - timestamp_ntz_output_format: Snowflake output format for NTZ timestamps. - date_output_format: Snowflake output format for dates. - time_output_format: Snowflake output format for times. - timezone: Snowflake timezone. - binary_output_format: Snowflake output format for binary fields. - arraysize: The default number of rows fetched by fetchmany. - connection: The connection object by which the cursor was created. - errorhandle: The class that handles error handling. - is_file_transfer: Whether, or not the current command is a put, or get. - """ +class SnowflakeCursorBase(abc.ABC, Generic[FetchRow]): # TODO: # Most of these attributes have no reason to be properties, we could just store them in public variables. @@ -382,13 +366,11 @@ def get_file_transfer_type(sql: str) -> FileTransferType | None: def __init__( self, connection: SnowflakeConnection, - use_dict_result: bool = False, ) -> None: """Inits a SnowflakeCursor with a connection. Args: connection: The connection that created this cursor. - use_dict_result: Decides whether to use dict result or not. 
""" self._connection: SnowflakeConnection = connection @@ -423,7 +405,6 @@ def __init__( self._result: Iterator[tuple] | Iterator[dict] | None = None self._result_set: ResultSet | None = None self._result_state: ResultState = ResultState.DEFAULT - self._use_dict_result = use_dict_result self.query: str | None = None # TODO: self._query_result_format could be defined as an enum self._query_result_format: str | None = None @@ -435,7 +416,7 @@ def __init__( self._first_chunk_time = None self._log_max_query_length = connection.log_max_query_length - self._inner_cursor: SnowflakeCursor | None = None + self._inner_cursor: SnowflakeCursorBase | None = None self._prefetch_hook = None self._rownumber: int | None = None @@ -448,6 +429,12 @@ def __del__(self) -> None: # pragma: no cover if logger.getEffectiveLevel() <= logging.INFO: logger.info(e) + @property + @abc.abstractmethod + def _use_dict_result(self) -> bool: + """Decides whether results from helper functions are returned as a dict.""" + pass + @property def description(self) -> list[ResultMetadata]: if self._description is None: @@ -1514,8 +1501,17 @@ def executemany( return self - def fetchone(self) -> dict | tuple | None: - """Fetches one row.""" + @abc.abstractmethod + def fetchone(self) -> FetchRow | None: + """Fetches one row, or None when the underlying `_fetchone` yields no row.""" + + def _fetchone(self) -> dict[str, Any] | tuple[Any, ...] | None: + """ + Fetches one row. + + Returns a dict if self._use_dict_result is True, otherwise + returns tuple. 
+ """ if self._prefetch_hook is not None: self._prefetch_hook() if self._result is None and self._result_set is not None: @@ -1539,7 +1535,7 @@ def fetchone(self) -> dict | tuple | None: else: return None - def fetchmany(self, size: int | None = None) -> list[tuple] | list[dict]: + def fetchmany(self, size: int | None = None) -> list[FetchRow]: """Fetches the number of specified rows.""" if size is None: size = self.arraysize @@ -1565,7 +1561,7 @@ def fetchmany(self, size: int | None = None) -> list[tuple] | list[dict]: return ret - def fetchall(self) -> list[tuple] | list[dict]: + def fetchall(self) -> list[FetchRow]: """Fetches all of the results.""" ret = [] while True: @@ -1728,20 +1724,31 @@ def wait_until_ready() -> None: # Unset this function, so that we don't block anymore self._prefetch_hook = None - if ( - self._inner_cursor._total_rowcount == 1 - and self._inner_cursor.fetchall() - == [("Multiple statements executed successfully.",)] + if self._inner_cursor._total_rowcount == 1 and _is_successful_multi_stmt( + self._inner_cursor.fetchall() ): url = f"/queries/{sfqid}/result" ret = self._connection.rest.request(url=url, method="get") if "data" in ret and "resultIds" in ret["data"]: self._init_multi_statement_results(ret["data"]) + def _is_successful_multi_stmt(rows: list[Any]) -> bool: + if len(rows) != 1: + return False + row = rows[0] + if isinstance(row, tuple): + return row == ("Multiple statements executed successfully.",) + elif isinstance(row, dict): + return row == { + "multiple statement execution": "Multiple statements executed successfully." 
+ } + else: + return False + self.connection.get_query_status_throw_if_error( sfqid ) # Trigger an exception if query failed - self._inner_cursor = SnowflakeCursor(self.connection) + self._inner_cursor = self.__class__(self.connection) self._sfqid = sfqid self._prefetch_hook = wait_until_ready @@ -1925,14 +1932,53 @@ def _create_file_transfer_agent( ) -class DictCursor(SnowflakeCursor): +class SnowflakeCursor(SnowflakeCursorBase[tuple[Any, ...]]): + """Implementation of Cursor object that is returned from Connection.cursor() method. + + Attributes: + description: A list of namedtuples about metadata for all columns. + rowcount: The number of records updated or selected. If not clear, -1 is returned. + rownumber: The current 0-based index of the cursor in the result set or None if the index cannot be + determined. + sfqid: Snowflake query id in UUID form. Include this in the problem report to the customer support. + sqlstate: Snowflake SQL State code. + timestamp_output_format: Snowflake timestamp_output_format for timestamps. + timestamp_ltz_output_format: Snowflake output format for LTZ timestamps. + timestamp_tz_output_format: Snowflake output format for TZ timestamps. + timestamp_ntz_output_format: Snowflake output format for NTZ timestamps. + date_output_format: Snowflake output format for dates. + time_output_format: Snowflake output format for times. + timezone: Snowflake timezone. + binary_output_format: Snowflake output format for binary fields. + arraysize: The default number of rows fetched by fetchmany. + connection: The connection object by which the cursor was created. + errorhandle: The class that handles error handling. + is_file_transfer: Whether, or not the current command is a put, or get. + """ + + @property + def _use_dict_result(self) -> bool: + return False + + def fetchone(self) -> tuple[Any, ...] 
| None: + row = self._fetchone() + if not (row is None or isinstance(row, tuple)): + raise TypeError(f"fetchone got unexpected result: {row}") + return row + + +class DictCursor(SnowflakeCursorBase[dict[str, Any]]): """Cursor returning results in a dictionary.""" - def __init__(self, connection) -> None: - super().__init__( - connection, - use_dict_result=True, - ) + @property + def _use_dict_result(self) -> bool: + return True + + def fetchone(self) -> dict[str, Any] | None: + row = self._fetchone() + if not (row is None or isinstance(row, dict)): + raise TypeError(f"fetchone got unexpected result: {row}") + return row def __getattr__(name): diff --git a/test/integ/test_async.py b/test/integ/test_async.py index eec0861f13..a9d3afdf45 100644 --- a/test/integ/test_async.py +++ b/test/integ/test_async.py @@ -18,8 +18,10 @@ QueryStatus = None -@pytest.mark.parametrize("cursor_class", [SnowflakeCursor, DictCursor]) -def test_simple_async(conn_cnx, cursor_class): +@pytest.mark.parametrize( + "cursor_class, row_type", [(SnowflakeCursor, tuple), (DictCursor, dict)] +) +def test_simple_async(conn_cnx, cursor_class, row_type): """Simple test to that shows the most simple usage of fire and forget. 
This test also makes sure that wait_until_ready function's sleeping is tested and @@ -29,7 +31,9 @@ def test_simple_async(conn_cnx, cursor_class): with con.cursor(cursor_class) as cur: cur.execute_async("select count(*) from table(generator(timeLimit => 5))") cur.get_results_from_sfqid(cur.sfqid) - assert len(cur.fetchall()) == 1 + rows = cur.fetchall() + assert len(rows) == 1 + assert isinstance(rows[0], row_type) assert cur.rowcount assert cur.description diff --git a/test/unit/test_cursor.py b/test/unit/test_cursor.py index c936a3928e..205f65d387 100644 --- a/test/unit/test_cursor.py +++ b/test/unit/test_cursor.py @@ -78,7 +78,7 @@ def mock_is_closed(*args, **kwargs): cursor.execute("", _dataframe_ast="ABCD") -@patch("snowflake.connector.cursor.SnowflakeCursor._SnowflakeCursor__cancel_query") +@patch("snowflake.connector.cursor.SnowflakeCursor._SnowflakeCursorBase__cancel_query") def test_cursor_execute_timeout(mockCancelQuery): def mock_cmd_query(*args, **kwargs): time.sleep(10)