Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions DESCRIPTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
- Added support for intermediate certificates as roots when they are stored in the trust store
- Bumped up vendored `urllib3` to `2.5.0` and `requests` to `v2.32.5`
- Dropped support for OpenSSL versions older than 1.1.1
- Fixed the return type of `SnowflakeConnection.cursor(cursor_class)` to match the type of `cursor_class`
- Constrained the types of `fetchone`, `fetchmany`, `fetchall`
- As part of this fix, `DictCursor` is no longer a subclass of `SnowflakeCursor`; use `SnowflakeCursorBase` as a superclass of both.

- v3.17.3(September 02,2025)
- Enhanced configuration file permission warning messages.
Expand Down
23 changes: 18 additions & 5 deletions src/snowflake/connector/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import re
import sys
import traceback
import typing
import uuid
import warnings
import weakref
Expand All @@ -20,7 +21,16 @@
from logging import getLogger
from threading import Lock
from types import TracebackType
from typing import Any, Callable, Generator, Iterable, Iterator, NamedTuple, Sequence
from typing import (
Any,
Callable,
Generator,
Iterable,
Iterator,
NamedTuple,
Sequence,
TypeVar,
)
from uuid import UUID

from cryptography.hazmat.backends import default_backend
Expand Down Expand Up @@ -76,7 +86,7 @@
QueryStatus,
)
from .converter import SnowflakeConverter
from .cursor import LOG_MAX_QUERY_LENGTH, SnowflakeCursor
from .cursor import LOG_MAX_QUERY_LENGTH, SnowflakeCursor, SnowflakeCursorBase
from .description import (
CLIENT_NAME,
CLIENT_VERSION,
Expand Down Expand Up @@ -125,6 +135,11 @@
from .util_text import construct_hostname, parse_account, split_statements
from .wif_util import AttestationProvider

if sys.version_info >= (3, 13) or typing.TYPE_CHECKING:
CursorCls = TypeVar("CursorCls", bound=SnowflakeCursorBase, default=SnowflakeCursor)
else:
CursorCls = TypeVar("CursorCls", bound=SnowflakeCursorBase)

DEFAULT_CLIENT_PREFETCH_THREADS = 4
MAX_CLIENT_PREFETCH_THREADS = 10
MAX_CLIENT_FETCH_THREADS = 1024
Expand Down Expand Up @@ -1060,9 +1075,7 @@ def rollback(self) -> None:
"""Rolls back the current transaction."""
self.cursor().execute("ROLLBACK")

def cursor(
self, cursor_class: type[SnowflakeCursor] = SnowflakeCursor
) -> SnowflakeCursor:
def cursor(self, cursor_class: type[CursorCls] = SnowflakeCursor) -> CursorCls:
"""Creates a cursor object. Each statement will be executed in a new cursor object."""
logger.debug("cursor")
if not self.rest:
Expand Down
128 changes: 86 additions & 42 deletions src/snowflake/connector/cursor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python
from __future__ import annotations

import abc
import collections
import logging
import os
Expand All @@ -19,12 +20,16 @@
TYPE_CHECKING,
Any,
Callable,
Dict,
Generic,
Iterator,
Literal,
NamedTuple,
NoReturn,
Sequence,
Tuple,
TypeVar,
Union,
overload,
)

Expand Down Expand Up @@ -86,6 +91,7 @@
from .result_batch import ResultBatch

T = TypeVar("T", bound=collections.abc.Sequence)
FetchRow = TypeVar("FetchRow", bound=Union[Tuple[Any, ...], Dict[str, Any]])

logger = getLogger(__name__)

Expand Down Expand Up @@ -332,29 +338,7 @@ class ResultState(Enum):
RESET = 3


class SnowflakeCursor:
"""Implementation of Cursor object that is returned from Connection.cursor() method.

Attributes:
description: A list of namedtuples about metadata for all columns.
rowcount: The number of records updated or selected. If not clear, -1 is returned.
rownumber: The current 0-based index of the cursor in the result set or None if the index cannot be
determined.
sfqid: Snowflake query id in UUID form. Include this in the problem report to the customer support.
sqlstate: Snowflake SQL State code.
timestamp_output_format: Snowflake timestamp_output_format for timestamps.
timestamp_ltz_output_format: Snowflake output format for LTZ timestamps.
timestamp_tz_output_format: Snowflake output format for TZ timestamps.
timestamp_ntz_output_format: Snowflake output format for NTZ timestamps.
date_output_format: Snowflake output format for dates.
time_output_format: Snowflake output format for times.
timezone: Snowflake timezone.
binary_output_format: Snowflake output format for binary fields.
arraysize: The default number of rows fetched by fetchmany.
connection: The connection object by which the cursor was created.
errorhandle: The class that handles error handling.
is_file_transfer: Whether, or not the current command is a put, or get.
"""
class SnowflakeCursorBase(abc.ABC, Generic[FetchRow]):

# TODO:
# Most of these attributes have no reason to be properties, we could just store them in public variables.
Expand Down Expand Up @@ -382,13 +366,11 @@ def get_file_transfer_type(sql: str) -> FileTransferType | None:
def __init__(
self,
connection: SnowflakeConnection,
use_dict_result: bool = False,
) -> None:
"""Inits a SnowflakeCursor with a connection.

Args:
connection: The connection that created this cursor.
use_dict_result: Decides whether to use dict result or not.
"""
self._connection: SnowflakeConnection = connection

Expand Down Expand Up @@ -423,7 +405,6 @@ def __init__(
self._result: Iterator[tuple] | Iterator[dict] | None = None
self._result_set: ResultSet | None = None
self._result_state: ResultState = ResultState.DEFAULT
self._use_dict_result = use_dict_result
self.query: str | None = None
# TODO: self._query_result_format could be defined as an enum
self._query_result_format: str | None = None
Expand All @@ -435,7 +416,7 @@ def __init__(
self._first_chunk_time = None

self._log_max_query_length = connection.log_max_query_length
self._inner_cursor: SnowflakeCursor | None = None
self._inner_cursor: SnowflakeCursorBase | None = None
self._prefetch_hook = None
self._rownumber: int | None = None

Expand All @@ -448,6 +429,10 @@ def __del__(self) -> None: # pragma: no cover
if logger.getEffectiveLevel() <= logging.INFO:
logger.info(e)

@property
@abc.abstractmethod
def _use_dict_result(self) -> bool: ...

@property
def description(self) -> list[ResultMetadata]:
if self._description is None:
Expand Down Expand Up @@ -1514,8 +1499,17 @@ def executemany(

return self

def fetchone(self) -> dict | tuple | None:
"""Fetches one row."""
@abc.abstractmethod
def fetchone(self) -> FetchRow:
pass

def _fetchone(self) -> dict[str, Any] | tuple[Any, ...] | None:
"""
Fetches one row.

Returns a dict if self._use_dict_result is True, otherwise
returns tuple.
"""
if self._prefetch_hook is not None:
self._prefetch_hook()
if self._result is None and self._result_set is not None:
Expand All @@ -1539,7 +1533,7 @@ def fetchone(self) -> dict | tuple | None:
else:
return None

def fetchmany(self, size: int | None = None) -> list[tuple] | list[dict]:
def fetchmany(self, size: int | None = None) -> list[FetchRow]:
"""Fetches the number of specified rows."""
if size is None:
size = self.arraysize
Expand All @@ -1565,7 +1559,7 @@ def fetchmany(self, size: int | None = None) -> list[tuple] | list[dict]:

return ret

def fetchall(self) -> list[tuple] | list[dict]:
def fetchall(self) -> list[FetchRow]:
"""Fetches all of the results."""
ret = []
while True:
Expand Down Expand Up @@ -1728,20 +1722,31 @@ def wait_until_ready() -> None:
# Unset this function, so that we don't block anymore
self._prefetch_hook = None

if (
self._inner_cursor._total_rowcount == 1
and self._inner_cursor.fetchall()
== [("Multiple statements executed successfully.",)]
if self._inner_cursor._total_rowcount == 1 and _is_successful_multi_stmt(
self._inner_cursor.fetchall()
):
url = f"/queries/{sfqid}/result"
ret = self._connection.rest.request(url=url, method="get")
if "data" in ret and "resultIds" in ret["data"]:
self._init_multi_statement_results(ret["data"])

def _is_successful_multi_stmt(rows: list[Any]) -> bool:
if len(rows) != 1:
return False
row = rows[0]
if isinstance(row, tuple):
return row == ("Multiple statements executed successfully.",)
elif isinstance(row, dict):
return row == {
"multiple statement execution": "Multiple statements executed successfully."
}
else:
return False

self.connection.get_query_status_throw_if_error(
sfqid
) # Trigger an exception if query failed
self._inner_cursor = SnowflakeCursor(self.connection)
self._inner_cursor = self.__class__(self.connection)
self._sfqid = sfqid
self._prefetch_hook = wait_until_ready

Expand Down Expand Up @@ -1925,14 +1930,53 @@ def _create_file_transfer_agent(
)


class DictCursor(SnowflakeCursor):
class SnowflakeCursor(SnowflakeCursorBase[tuple[Any, ...]]):
"""Implementation of Cursor object that is returned from Connection.cursor() method.

Attributes:
description: A list of namedtuples about metadata for all columns.
rowcount: The number of records updated or selected. If not clear, -1 is returned.
rownumber: The current 0-based index of the cursor in the result set or None if the index cannot be
determined.
sfqid: Snowflake query id in UUID form. Include this in the problem report to the customer support.
sqlstate: Snowflake SQL State code.
timestamp_output_format: Snowflake timestamp_output_format for timestamps.
timestamp_ltz_output_format: Snowflake output format for LTZ timestamps.
timestamp_tz_output_format: Snowflake output format for TZ timestamps.
timestamp_ntz_output_format: Snowflake output format for NTZ timestamps.
date_output_format: Snowflake output format for dates.
time_output_format: Snowflake output format for times.
timezone: Snowflake timezone.
binary_output_format: Snowflake output format for binary fields.
arraysize: The default number of rows fetched by fetchmany.
connection: The connection object by which the cursor was created.
errorhandle: The class that handles error handling.
is_file_transfer: Whether, or not the current command is a put, or get.
"""

@property
def _use_dict_result(self) -> bool:
return False

def fetchone(self) -> tuple[Any, ...] | None:
row = self._fetchone()
if not (row is None or isinstance(row, tuple)):
raise TypeError(f"fetchone got unexpected result: {row}")
return row


class DictCursor(SnowflakeCursorBase[dict[str, Any]]):
"""Cursor returning results in a dictionary."""

def __init__(self, connection) -> None:
super().__init__(
connection,
use_dict_result=True,
)
@property
def _use_dict_result(self) -> bool:
return True

def fetchone(self) -> dict[str, Any] | None:
row = self._fetchone()
if not (row is None or isinstance(row, dict)):
raise TypeError(f"fetchone got unexpected result: {row}")
return row


def __getattr__(name):
Expand Down
10 changes: 7 additions & 3 deletions test/integ/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
QueryStatus = None


@pytest.mark.parametrize("cursor_class", [SnowflakeCursor, DictCursor])
def test_simple_async(conn_cnx, cursor_class):
@pytest.mark.parametrize(
"cursor_class, row_type", [(SnowflakeCursor, tuple), (DictCursor, dict)]
)
def test_simple_async(conn_cnx, cursor_class, row_type):
"""Simple test to that shows the most simple usage of fire and forget.
This test also makes sure that wait_until_ready function's sleeping is tested and
Expand All @@ -29,7 +31,9 @@ def test_simple_async(conn_cnx, cursor_class):
with con.cursor(cursor_class) as cur:
cur.execute_async("select count(*) from table(generator(timeLimit => 5))")
cur.get_results_from_sfqid(cur.sfqid)
assert len(cur.fetchall()) == 1
rows = cur.fetchall()
assert len(rows) == 1
assert isinstance(rows[0], row_type)
assert cur.rowcount
assert cur.description

Expand Down
2 changes: 1 addition & 1 deletion test/unit/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def mock_is_closed(*args, **kwargs):
cursor.execute("", _dataframe_ast="ABCD")


@patch("snowflake.connector.cursor.SnowflakeCursor._SnowflakeCursor__cancel_query")
@patch("snowflake.connector.cursor.SnowflakeCursor._SnowflakeCursorBase__cancel_query")
def test_cursor_execute_timeout(mockCancelQuery):
def mock_cmd_query(*args, **kwargs):
time.sleep(10)
Expand Down
Loading