diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 3e02e72660..d463bca9e3 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -10,6 +10,19 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne - v3.18.0(TBD) - Added the `workload_identity_impersonation_path` parameter to support service account impersonation for Workload Identity Federation on GCP and AWS workloads only - Fixed `get_results_from_sfqid` when using `DictCursor` and executing multiple statements at once + - Added the `oauth_credentials_in_body` parameter supporting an option to send the oauth client credentials in the request body + - Fixed retry behavior for `ECONNRESET` error + - Added an option to exclude `botocore` and `boto3` dependencies by setting the `SNOWFLAKE_NO_BOTO` environment variable during installation + - Reverted the exception type raised in the token expired scenario for the `Oauth` authenticator back to `DatabaseError` + - Added support for pandas conversion for Day-time and Year-Month Interval types + +- v3.17.4(September 22,2025) + - Added support for intermediate certificates as roots when they are stored in the trust store + - Bumped up vendored `urllib3` to `2.5.0` and `requests` to `v2.32.5` + - Dropped support for OpenSSL versions older than 1.1.1 + - Fixed the return type of `SnowflakeConnection.cursor(cursor_class)` to match the type of `cursor_class` + - Constrained the types of `fetchone`, `fetchmany`, `fetchall` + - As part of this fix, `DictCursor` is no longer a subclass of `SnowflakeCursor`; use `SnowflakeCursorBase` as a superclass of both. - v3.17.3(September 02,2025) - Enhanced configuration file permission warning messages. 
diff --git a/src/snowflake/connector/aio/_connection.py b/src/snowflake/connector/aio/_connection.py index 479af373ad..879b6f774d 100644 --- a/src/snowflake/connector/aio/_connection.py +++ b/src/snowflake/connector/aio/_connection.py @@ -7,13 +7,14 @@ import os import pathlib import sys +import typing import uuid import warnings from contextlib import suppress from io import StringIO from logging import getLogger from types import TracebackType -from typing import Any, AsyncIterator, Iterable +from typing import Any, AsyncIterator, Iterable, TypeVar from snowflake.connector import ( DatabaseError, @@ -73,7 +74,7 @@ from ..time_util import get_time_millis from ..util_text import split_statements from ..wif_util import AttestationProvider -from ._cursor import SnowflakeCursor +from ._cursor import SnowflakeCursor, SnowflakeCursorBase from ._description import CLIENT_NAME from ._direct_file_operation_utils import FileOperationParser, StreamDownloader from ._network import SnowflakeRestful @@ -107,6 +108,11 @@ DEFAULT_CONFIGURATION = copy.deepcopy(DEFAULT_CONFIGURATION_SYNC) DEFAULT_CONFIGURATION["application"] = (CLIENT_NAME, (type(None), str)) +if sys.version_info >= (3, 13) or typing.TYPE_CHECKING: + CursorCls = TypeVar("CursorCls", bound=SnowflakeCursorBase, default=SnowflakeCursor) +else: + CursorCls = TypeVar("CursorCls", bound=SnowflakeCursorBase) + class SnowflakeConnection(SnowflakeConnectionSync): OCSP_ENV_LOCK = asyncio.Lock() @@ -1031,9 +1037,7 @@ async def connect(self, **kwargs) -> None: self._telemetry = TelemetryClient(self._rest) await self._log_telemetry_imported_packages() - def cursor( - self, cursor_class: type[SnowflakeCursor] = SnowflakeCursor - ) -> SnowflakeCursor: + def cursor(self, cursor_class: type[CursorCls] = SnowflakeCursor) -> CursorCls: logger.debug("cursor") if not self.rest: Error.errorhandler_wrapper( diff --git a/src/snowflake/connector/aio/_cursor.py b/src/snowflake/connector/aio/_cursor.py index 24a9b5da03..f7868b2ecc 100644 --- 
a/src/snowflake/connector/aio/_cursor.py +++ b/src/snowflake/connector/aio/_cursor.py @@ -1,5 +1,6 @@ from __future__ import annotations +import abc import asyncio import collections import logging @@ -39,10 +40,11 @@ ASYNC_NO_DATA_MAX_RETRY, ASYNC_RETRY_PATTERN, DESC_TABLE_RE, + ResultMetadata, + ResultMetadataV2, + ResultState, ) -from snowflake.connector.cursor import DictCursor as DictCursorSync -from snowflake.connector.cursor import ResultMetadata, ResultMetadataV2, ResultState -from snowflake.connector.cursor import SnowflakeCursor as SnowflakeCursorSync +from snowflake.connector.cursor import SnowflakeCursorBase as SnowflakeCursorBaseSync from snowflake.connector.cursor import T from snowflake.connector.errorcode import ( ER_CURSOR_IS_CLOSED, @@ -66,17 +68,20 @@ logger = getLogger(__name__) +FetchRow = typing.TypeVar( + "FetchRow", bound=typing.Union[typing.Tuple[Any, ...], typing.Dict[str, Any]] +) + -class SnowflakeCursor(SnowflakeCursorSync): +class SnowflakeCursorBase(SnowflakeCursorBaseSync, abc.ABC, typing.Generic[FetchRow]): def __init__( self, connection: SnowflakeConnection, - use_dict_result: bool = False, ): - super().__init__(connection, use_dict_result) + super().__init__(connection) # the following fixes type hint self._connection = typing.cast("SnowflakeConnection", self._connection) - self._inner_cursor: SnowflakeCursor | None = None + self._inner_cursor: SnowflakeCursorBase | None = None self._lock_canceling = asyncio.Lock() self._timebomb: asyncio.Task | None = None self._prefetch_hook: typing.Callable[[], typing.Awaitable] | None = None @@ -900,8 +905,17 @@ async def describe(self, *args: Any, **kwargs: Any) -> list[ResultMetadata]: return None return [meta._to_result_metadata_v1() for meta in self._description] - async def fetchone(self) -> dict | tuple | None: - """Fetches one row.""" + @abc.abstractmethod + async def fetchone(self) -> FetchRow: + pass + + async def _fetchone(self) -> dict[str, Any] | tuple[Any, ...] 
| None: + """ + Fetches one row. + + Returns a dict if self._use_dict_result is True, otherwise + returns tuple. + """ if self._prefetch_hook is not None: await self._prefetch_hook() if self._result is None and self._result_set is not None: @@ -926,7 +940,7 @@ async def fetchone(self) -> dict | tuple | None: else: return None - async def fetchmany(self, size: int | None = None) -> list[tuple] | list[dict]: + async def fetchmany(self, size: int | None = None) -> list[FetchRow]: """Fetches the number of specified rows.""" if size is None: size = self.arraysize @@ -1266,20 +1280,31 @@ async def wait_until_ready() -> None: # Unset this function, so that we don't block anymore self._prefetch_hook = None - if ( - self._inner_cursor._total_rowcount == 1 - and await self._inner_cursor.fetchall() - == [("Multiple statements executed successfully.",)] + if self._inner_cursor._total_rowcount == 1 and _is_successful_multi_stmt( + await self._inner_cursor.fetchall() ): url = f"/queries/{sfqid}/result" ret = await self._connection.rest.request(url=url, method="get") if "data" in ret and "resultIds" in ret["data"]: await self._init_multi_statement_results(ret["data"]) + def _is_successful_multi_stmt(rows: list[Any]) -> bool: + if len(rows) != 1: + return False + row = rows[0] + if isinstance(row, tuple): + return row == ("Multiple statements executed successfully.",) + elif isinstance(row, dict): + return row == { + "multiple statement execution": "Multiple statements executed successfully." 
+ } + else: + return False + await self.connection.get_query_status_throw_if_error( sfqid ) # Trigger an exception if query failed - self._inner_cursor = SnowflakeCursor(self.connection) + self._inner_cursor = self.__class__(self.connection) self._sfqid = sfqid self._prefetch_hook = wait_until_ready @@ -1321,5 +1346,50 @@ async def query_result(self, qid: str) -> SnowflakeCursor: return self -class DictCursor(DictCursorSync, SnowflakeCursor): - pass +class SnowflakeCursor(SnowflakeCursorBase[tuple[Any, ...]]): + """Implementation of Cursor object that is returned from Connection.cursor() method. + + Attributes: + description: A list of namedtuples about metadata for all columns. + rowcount: The number of records updated or selected. If not clear, -1 is returned. + rownumber: The current 0-based index of the cursor in the result set or None if the index cannot be + determined. + sfqid: Snowflake query id in UUID form. Include this in the problem report to the customer support. + sqlstate: Snowflake SQL State code. + timestamp_output_format: Snowflake timestamp_output_format for timestamps. + timestamp_ltz_output_format: Snowflake output format for LTZ timestamps. + timestamp_tz_output_format: Snowflake output format for TZ timestamps. + timestamp_ntz_output_format: Snowflake output format for NTZ timestamps. + date_output_format: Snowflake output format for dates. + time_output_format: Snowflake output format for times. + timezone: Snowflake timezone. + binary_output_format: Snowflake output format for binary fields. + arraysize: The default number of rows fetched by fetchmany. + connection: The connection object by which the cursor was created. + errorhandle: The class that handles error handling. + is_file_transfer: Whether, or not the current command is a put, or get. + """ + + @property + def _use_dict_result(self) -> bool: + return False + + async def fetchone(self) -> tuple[Any, ...] 
| None: + row = await self._fetchone() + if not (row is None or isinstance(row, tuple)): + raise TypeError(f"fetchone got unexpected result: {row}") + return row + + +class DictCursor(SnowflakeCursorBase[dict[str, Any]]): + """Cursor returning results in a dictionary.""" + + @property + def _use_dict_result(self) -> bool: + return True + + async def fetchone(self) -> dict[str, Any] | None: + row = await self._fetchone() + if not (row is None or isinstance(row, dict)): + raise TypeError(f"fetchone got unexpected result: {row}") + return row diff --git a/src/snowflake/connector/aio/auth/_auth.py b/src/snowflake/connector/aio/auth/_auth.py index b8c6564837..870252ecd7 100644 --- a/src/snowflake/connector/aio/auth/_auth.py +++ b/src/snowflake/connector/aio/auth/_auth.py @@ -62,6 +62,8 @@ async def authenticate( # max time waiting for MFA response, currently unused timeout: int | None = None, ) -> dict[str, str | int | bool]: + from . import AuthByOAuth + if mfa_callback or password_callback: # TODO: SNOW-1707210 for mfa_callback and password_callback support raise NotImplementedError( @@ -285,7 +287,11 @@ async def post_request_wrapper(self, url, headers, body) -> None: sqlstate=SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED, ) ) - elif errno == OAUTH_ACCESS_TOKEN_EXPIRED_GS_CODE: + elif (errno == OAUTH_ACCESS_TOKEN_EXPIRED_GS_CODE) and ( + # SNOW-2329031: OAuth v1.0 does not support token renewal, + # for backward compatibility, we do not raise an exception here + not isinstance(auth_instance, AuthByOAuth) + ): raise ReauthenticationRequest( ProgrammingError( msg=ret["message"], diff --git a/src/snowflake/connector/connection.py b/src/snowflake/connector/connection.py index 000907d5c4..64b73f8aef 100644 --- a/src/snowflake/connector/connection.py +++ b/src/snowflake/connector/connection.py @@ -8,6 +8,7 @@ import re import sys import traceback +import typing import uuid import warnings import weakref @@ -20,7 +21,16 @@ from logging import getLogger from threading import Lock from types 
import TracebackType -from typing import Any, Callable, Generator, Iterable, Iterator, NamedTuple, Sequence +from typing import ( + Any, + Callable, + Generator, + Iterable, + Iterator, + NamedTuple, + Sequence, + TypeVar, +) from uuid import UUID from cryptography.hazmat.backends import default_backend @@ -76,7 +86,7 @@ QueryStatus, ) from .converter import SnowflakeConverter -from .cursor import LOG_MAX_QUERY_LENGTH, SnowflakeCursor +from .cursor import LOG_MAX_QUERY_LENGTH, SnowflakeCursor, SnowflakeCursorBase from .description import ( CLIENT_NAME, CLIENT_VERSION, @@ -125,6 +135,11 @@ from .util_text import construct_hostname, parse_account, split_statements from .wif_util import AttestationProvider +if sys.version_info >= (3, 13) or typing.TYPE_CHECKING: + CursorCls = TypeVar("CursorCls", bound=SnowflakeCursorBase, default=SnowflakeCursor) +else: + CursorCls = TypeVar("CursorCls", bound=SnowflakeCursorBase) + DEFAULT_CLIENT_PREFETCH_THREADS = 4 MAX_CLIENT_PREFETCH_THREADS = 10 MAX_CLIENT_FETCH_THREADS = 1024 @@ -1050,9 +1065,7 @@ def rollback(self) -> None: """Rolls back the current transaction.""" self.cursor().execute("ROLLBACK") - def cursor( - self, cursor_class: type[SnowflakeCursor] = SnowflakeCursor - ) -> SnowflakeCursor: + def cursor(self, cursor_class: type[CursorCls] = SnowflakeCursor) -> CursorCls: """Creates a cursor object. 
Each statement will be executed in a new cursor object.""" logger.debug("cursor") if not self.rest: diff --git a/src/snowflake/connector/cursor.py b/src/snowflake/connector/cursor.py index a8ec738986..6738c486cb 100644 --- a/src/snowflake/connector/cursor.py +++ b/src/snowflake/connector/cursor.py @@ -1,6 +1,7 @@ #!/usr/bin/env python from __future__ import annotations +import abc import collections import logging import os @@ -19,12 +20,16 @@ TYPE_CHECKING, Any, Callable, + Dict, + Generic, Iterator, Literal, NamedTuple, NoReturn, Sequence, + Tuple, TypeVar, + Union, overload, ) @@ -83,6 +88,7 @@ from .result_batch import ResultBatch T = TypeVar("T", bound=collections.abc.Sequence) +FetchRow = TypeVar("FetchRow", bound=Union[Tuple[Any, ...], Dict[str, Any]]) logger = getLogger(__name__) @@ -329,29 +335,7 @@ class ResultState(Enum): RESET = 3 -class SnowflakeCursor: - """Implementation of Cursor object that is returned from Connection.cursor() method. - - Attributes: - description: A list of namedtuples about metadata for all columns. - rowcount: The number of records updated or selected. If not clear, -1 is returned. - rownumber: The current 0-based index of the cursor in the result set or None if the index cannot be - determined. - sfqid: Snowflake query id in UUID form. Include this in the problem report to the customer support. - sqlstate: Snowflake SQL State code. - timestamp_output_format: Snowflake timestamp_output_format for timestamps. - timestamp_ltz_output_format: Snowflake output format for LTZ timestamps. - timestamp_tz_output_format: Snowflake output format for TZ timestamps. - timestamp_ntz_output_format: Snowflake output format for NTZ timestamps. - date_output_format: Snowflake output format for dates. - time_output_format: Snowflake output format for times. - timezone: Snowflake timezone. - binary_output_format: Snowflake output format for binary fields. - arraysize: The default number of rows fetched by fetchmany. 
- connection: The connection object by which the cursor was created. - errorhandle: The class that handles error handling. - is_file_transfer: Whether, or not the current command is a put, or get. - """ +class SnowflakeCursorBase(abc.ABC, Generic[FetchRow]): # TODO: # Most of these attributes have no reason to be properties, we could just store them in public variables. @@ -379,13 +363,11 @@ def get_file_transfer_type(sql: str) -> FileTransferType | None: def __init__( self, connection: SnowflakeConnection, - use_dict_result: bool = False, ) -> None: """Inits a SnowflakeCursor with a connection. Args: connection: The connection that created this cursor. - use_dict_result: Decides whether to use dict result or not. """ self._connection: SnowflakeConnection = connection @@ -420,7 +402,6 @@ def __init__( self._result: Iterator[tuple] | Iterator[dict] | None = None self._result_set: ResultSet | None = None self._result_state: ResultState = ResultState.DEFAULT - self._use_dict_result = use_dict_result self.query: str | None = None # TODO: self._query_result_format could be defined as an enum self._query_result_format: str | None = None @@ -432,7 +413,7 @@ def __init__( self._first_chunk_time = None self._log_max_query_length = connection.log_max_query_length - self._inner_cursor: SnowflakeCursor | None = None + self._inner_cursor: SnowflakeCursorBase | None = None self._prefetch_hook = None self._rownumber: int | None = None @@ -445,6 +426,12 @@ def __del__(self) -> None: # pragma: no cover if logger.getEffectiveLevel() <= logging.INFO: logger.info(e) + @property + @abc.abstractmethod + def _use_dict_result(self) -> bool: + """Decides whether results from helper functions are returned as a dict.""" + pass + @property def description(self) -> list[ResultMetadata]: if self._description is None: @@ -1522,8 +1509,17 @@ def executemany( return self - def fetchone(self) -> dict | tuple | None: - """Fetches one row.""" + @abc.abstractmethod + def fetchone(self) -> FetchRow: + 
pass + + def _fetchone(self) -> dict[str, Any] | tuple[Any, ...] | None: + """ + Fetches one row. + + Returns a dict if self._use_dict_result is True, otherwise + returns tuple. + """ if self._prefetch_hook is not None: self._prefetch_hook() if self._result is None and self._result_set is not None: @@ -1547,7 +1543,7 @@ def fetchone(self) -> dict | tuple | None: else: return None - def fetchmany(self, size: int | None = None) -> list[tuple] | list[dict]: + def fetchmany(self, size: int | None = None) -> list[FetchRow]: """Fetches the number of specified rows.""" if size is None: size = self.arraysize @@ -1573,7 +1569,7 @@ def fetchmany(self, size: int | None = None) -> list[tuple] | list[dict]: return ret - def fetchall(self) -> list[tuple] | list[dict]: + def fetchall(self) -> list[FetchRow]: """Fetches all of the results.""" ret = [] while True: @@ -1736,20 +1732,31 @@ def wait_until_ready() -> None: # Unset this function, so that we don't block anymore self._prefetch_hook = None - if ( - self._inner_cursor._total_rowcount == 1 - and self._inner_cursor.fetchall() - == [("Multiple statements executed successfully.",)] + if self._inner_cursor._total_rowcount == 1 and _is_successful_multi_stmt( + self._inner_cursor.fetchall() ): url = f"/queries/{sfqid}/result" ret = self._connection.rest.request(url=url, method="get") if "data" in ret and "resultIds" in ret["data"]: self._init_multi_statement_results(ret["data"]) + def _is_successful_multi_stmt(rows: list[Any]) -> bool: + if len(rows) != 1: + return False + row = rows[0] + if isinstance(row, tuple): + return row == ("Multiple statements executed successfully.",) + elif isinstance(row, dict): + return row == { + "multiple statement execution": "Multiple statements executed successfully." 
+ } + else: + return False + self.connection.get_query_status_throw_if_error( sfqid ) # Trigger an exception if query failed - self._inner_cursor = SnowflakeCursor(self.connection) + self._inner_cursor = self.__class__(self.connection) self._sfqid = sfqid self._prefetch_hook = wait_until_ready @@ -1929,14 +1936,53 @@ def _upload_stream( self._init_result_and_meta(file_transfer_agent.result()) -class DictCursor(SnowflakeCursor): +class SnowflakeCursor(SnowflakeCursorBase[tuple[Any, ...]]): + """Implementation of Cursor object that is returned from Connection.cursor() method. + + Attributes: + description: A list of namedtuples about metadata for all columns. + rowcount: The number of records updated or selected. If not clear, -1 is returned. + rownumber: The current 0-based index of the cursor in the result set or None if the index cannot be + determined. + sfqid: Snowflake query id in UUID form. Include this in the problem report to the customer support. + sqlstate: Snowflake SQL State code. + timestamp_output_format: Snowflake timestamp_output_format for timestamps. + timestamp_ltz_output_format: Snowflake output format for LTZ timestamps. + timestamp_tz_output_format: Snowflake output format for TZ timestamps. + timestamp_ntz_output_format: Snowflake output format for NTZ timestamps. + date_output_format: Snowflake output format for dates. + time_output_format: Snowflake output format for times. + timezone: Snowflake timezone. + binary_output_format: Snowflake output format for binary fields. + arraysize: The default number of rows fetched by fetchmany. + connection: The connection object by which the cursor was created. + errorhandle: The class that handles error handling. + is_file_transfer: Whether, or not the current command is a put, or get. + """ + + @property + def _use_dict_result(self) -> bool: + return False + + def fetchone(self) -> tuple[Any, ...] 
| None: + row = self._fetchone() + if not (row is None or isinstance(row, tuple)): + raise TypeError(f"fetchone got unexpected result: {row}") + return row + + +class DictCursor(SnowflakeCursorBase[dict[str, Any]]): """Cursor returning results in a dictionary.""" - def __init__(self, connection) -> None: - super().__init__( - connection, - use_dict_result=True, - ) + @property + def _use_dict_result(self) -> bool: + return True + + def fetchone(self) -> dict[str, Any] | None: + row = self._fetchone() + if not (row is None or isinstance(row, dict)): + raise TypeError(f"fetchone got unexpected result: {row}") + return row def __getattr__(name): diff --git a/test/integ/test_async.py b/test/integ/test_async.py index eec0861f13..a9d3afdf45 100644 --- a/test/integ/test_async.py +++ b/test/integ/test_async.py @@ -18,8 +18,10 @@ QueryStatus = None -@pytest.mark.parametrize("cursor_class", [SnowflakeCursor, DictCursor]) -def test_simple_async(conn_cnx, cursor_class): +@pytest.mark.parametrize( + "cursor_class, row_type", [(SnowflakeCursor, tuple), (DictCursor, dict)] +) +def test_simple_async(conn_cnx, cursor_class, row_type): """Simple test to that shows the most simple usage of fire and forget. 
This test also makes sure that wait_until_ready function's sleeping is tested and @@ -29,7 +31,9 @@ def test_simple_async(conn_cnx, cursor_class): with con.cursor(cursor_class) as cur: cur.execute_async("select count(*) from table(generator(timeLimit => 5))") cur.get_results_from_sfqid(cur.sfqid) - assert len(cur.fetchall()) == 1 + rows = cur.fetchall() + assert len(rows) == 1 + assert isinstance(rows[0], row_type) assert cur.rowcount assert cur.description diff --git a/test/unit/aio/test_cursor_async_unit.py b/test/unit/aio/test_cursor_async_unit.py index 39894c3bad..9927c7a240 100644 --- a/test/unit/aio/test_cursor_async_unit.py +++ b/test/unit/aio/test_cursor_async_unit.py @@ -79,7 +79,9 @@ def mock_is_closed(*args, **kwargs): await cursor.execute("", _dataframe_ast="ABCD") -@patch("snowflake.connector.aio._cursor.SnowflakeCursor._SnowflakeCursor__cancel_query") +@patch( + "snowflake.connector.aio._cursor.SnowflakeCursor._SnowflakeCursorBase__cancel_query" +) async def test_cursor_execute_timeout(mockCancelQuery): async def mock_cmd_query(*args, **kwargs): await asyncio.sleep(10) diff --git a/test/unit/test_cursor.py b/test/unit/test_cursor.py index 6970e6acfb..528d4ba469 100644 --- a/test/unit/test_cursor.py +++ b/test/unit/test_cursor.py @@ -75,7 +75,7 @@ def mock_is_closed(*args, **kwargs): cursor.execute("", _dataframe_ast="ABCD") -@patch("snowflake.connector.cursor.SnowflakeCursor._SnowflakeCursor__cancel_query") +@patch("snowflake.connector.cursor.SnowflakeCursor._SnowflakeCursorBase__cancel_query") def test_cursor_execute_timeout(mockCancelQuery): def mock_cmd_query(*args, **kwargs): time.sleep(10)