Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions DESCRIPTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
- Added support for intermediate certificates as roots when they are stored in the trust store
- Bumped up vendored `urllib3` to `2.5.0` and `requests` to `v2.32.5`
- Dropped support for OpenSSL versions older than 1.1.1
- Fixed the return type of `SnowflakeConnection.cursor(cursor_class)` to match the type of `cursor_class`
- Constrained the types of `fetchone`, `fetchmany`, `fetchall`
- As part of this fix, `DictCursor` is no longer a subclass of `SnowflakeCursor`; use `SnowflakeCursorBase` as a superclass of both.

- v3.17.3(September 02,2025)
- Enhanced configuration file permission warning messages.
Expand Down
23 changes: 18 additions & 5 deletions src/snowflake/connector/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import re
import sys
import traceback
import typing
import uuid
import warnings
import weakref
Expand All @@ -20,7 +21,16 @@
from logging import getLogger
from threading import Lock
from types import TracebackType
from typing import Any, Callable, Generator, Iterable, Iterator, NamedTuple, Sequence
from typing import (
Any,
Callable,
Generator,
Iterable,
Iterator,
NamedTuple,
Sequence,
TypeVar,
)
from uuid import UUID

from cryptography.hazmat.backends import default_backend
Expand Down Expand Up @@ -76,7 +86,7 @@
QueryStatus,
)
from .converter import SnowflakeConverter
from .cursor import LOG_MAX_QUERY_LENGTH, SnowflakeCursor
from .cursor import LOG_MAX_QUERY_LENGTH, SnowflakeCursor, SnowflakeCursorBase
from .description import (
CLIENT_NAME,
CLIENT_VERSION,
Expand Down Expand Up @@ -125,6 +135,11 @@
from .util_text import construct_hostname, parse_account, split_statements
from .wif_util import AttestationProvider

if sys.version_info >= (3, 13) or typing.TYPE_CHECKING:
CursorCls = TypeVar("CursorCls", bound=SnowflakeCursorBase, default=SnowflakeCursor)
else:
CursorCls = TypeVar("CursorCls", bound=SnowflakeCursorBase)

DEFAULT_CLIENT_PREFETCH_THREADS = 4
MAX_CLIENT_PREFETCH_THREADS = 10
MAX_CLIENT_FETCH_THREADS = 1024
Expand Down Expand Up @@ -1060,9 +1075,7 @@ def rollback(self) -> None:
"""Rolls back the current transaction."""
self.cursor().execute("ROLLBACK")

def cursor(
self, cursor_class: type[SnowflakeCursor] = SnowflakeCursor
) -> SnowflakeCursor:
def cursor(self, cursor_class: type[CursorCls] = SnowflakeCursor) -> CursorCls:
"""Creates a cursor object. Each statement will be executed in a new cursor object."""
logger.debug("cursor")
if not self.rest:
Expand Down
130 changes: 88 additions & 42 deletions src/snowflake/connector/cursor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python
from __future__ import annotations

import abc
import collections
import logging
import os
Expand All @@ -19,12 +20,16 @@
TYPE_CHECKING,
Any,
Callable,
Dict,
Generic,
Iterator,
Literal,
NamedTuple,
NoReturn,
Sequence,
Tuple,
TypeVar,
Union,
overload,
)

Expand Down Expand Up @@ -86,6 +91,7 @@
from .result_batch import ResultBatch

T = TypeVar("T", bound=collections.abc.Sequence)
FetchRow = TypeVar("FetchRow", bound=Union[Tuple[Any, ...], Dict[str, Any]])

logger = getLogger(__name__)

Expand Down Expand Up @@ -332,29 +338,7 @@ class ResultState(Enum):
RESET = 3


class SnowflakeCursor:
"""Implementation of Cursor object that is returned from Connection.cursor() method.

Attributes:
description: A list of namedtuples about metadata for all columns.
rowcount: The number of records updated or selected. If not clear, -1 is returned.
rownumber: The current 0-based index of the cursor in the result set or None if the index cannot be
determined.
sfqid: Snowflake query id in UUID form. Include this in the problem report to the customer support.
sqlstate: Snowflake SQL State code.
timestamp_output_format: Snowflake timestamp_output_format for timestamps.
timestamp_ltz_output_format: Snowflake output format for LTZ timestamps.
timestamp_tz_output_format: Snowflake output format for TZ timestamps.
timestamp_ntz_output_format: Snowflake output format for NTZ timestamps.
date_output_format: Snowflake output format for dates.
time_output_format: Snowflake output format for times.
timezone: Snowflake timezone.
binary_output_format: Snowflake output format for binary fields.
arraysize: The default number of rows fetched by fetchmany.
connection: The connection object by which the cursor was created.
errorhandle: The class that handles error handling.
is_file_transfer: Whether, or not the current command is a put, or get.
"""
class SnowflakeCursorBase(abc.ABC, Generic[FetchRow]):

# TODO:
# Most of these attributes have no reason to be properties, we could just store them in public variables.
Expand Down Expand Up @@ -382,13 +366,11 @@ def get_file_transfer_type(sql: str) -> FileTransferType | None:
def __init__(
self,
connection: SnowflakeConnection,
use_dict_result: bool = False,
) -> None:
"""Inits a SnowflakeCursor with a connection.

Args:
connection: The connection that created this cursor.
use_dict_result: Decides whether to use dict result or not.
"""
self._connection: SnowflakeConnection = connection

Expand Down Expand Up @@ -423,7 +405,6 @@ def __init__(
self._result: Iterator[tuple] | Iterator[dict] | None = None
self._result_set: ResultSet | None = None
self._result_state: ResultState = ResultState.DEFAULT
self._use_dict_result = use_dict_result
self.query: str | None = None
# TODO: self._query_result_format could be defined as an enum
self._query_result_format: str | None = None
Expand All @@ -435,7 +416,7 @@ def __init__(
self._first_chunk_time = None

self._log_max_query_length = connection.log_max_query_length
self._inner_cursor: SnowflakeCursor | None = None
self._inner_cursor: SnowflakeCursorBase | None = None
self._prefetch_hook = None
self._rownumber: int | None = None

Expand All @@ -448,6 +429,12 @@ def __del__(self) -> None: # pragma: no cover
if logger.getEffectiveLevel() <= logging.INFO:
logger.info(e)

@property
@abc.abstractmethod
def _use_dict_result(self) -> bool:
"""Decides whether results from helper functions are returned as a dict."""
pass

@property
def description(self) -> list[ResultMetadata]:
if self._description is None:
Expand Down Expand Up @@ -1514,8 +1501,17 @@ def executemany(

return self

def fetchone(self) -> dict | tuple | None:
"""Fetches one row."""
@abc.abstractmethod
def fetchone(self) -> FetchRow:
pass

def _fetchone(self) -> dict[str, Any] | tuple[Any, ...] | None:
"""
Fetches one row.

Returns a dict if self._use_dict_result is True, otherwise
returns tuple.
"""
if self._prefetch_hook is not None:
self._prefetch_hook()
if self._result is None and self._result_set is not None:
Expand All @@ -1539,7 +1535,7 @@ def fetchone(self) -> dict | tuple | None:
else:
return None

def fetchmany(self, size: int | None = None) -> list[tuple] | list[dict]:
def fetchmany(self, size: int | None = None) -> list[FetchRow]:
"""Fetches the number of specified rows."""
if size is None:
size = self.arraysize
Expand All @@ -1565,7 +1561,7 @@ def fetchmany(self, size: int | None = None) -> list[tuple] | list[dict]:

return ret

def fetchall(self) -> list[tuple] | list[dict]:
def fetchall(self) -> list[FetchRow]:
"""Fetches all of the results."""
ret = []
while True:
Expand Down Expand Up @@ -1728,20 +1724,31 @@ def wait_until_ready() -> None:
# Unset this function, so that we don't block anymore
self._prefetch_hook = None

if (
self._inner_cursor._total_rowcount == 1
and self._inner_cursor.fetchall()
== [("Multiple statements executed successfully.",)]
if self._inner_cursor._total_rowcount == 1 and _is_successful_multi_stmt(
self._inner_cursor.fetchall()
):
url = f"/queries/{sfqid}/result"
ret = self._connection.rest.request(url=url, method="get")
if "data" in ret and "resultIds" in ret["data"]:
self._init_multi_statement_results(ret["data"])

def _is_successful_multi_stmt(rows: list[Any]) -> bool:
if len(rows) != 1:
return False
row = rows[0]
if isinstance(row, tuple):
return row == ("Multiple statements executed successfully.",)
elif isinstance(row, dict):
return row == {
"multiple statement execution": "Multiple statements executed successfully."
}
else:
return False

self.connection.get_query_status_throw_if_error(
sfqid
) # Trigger an exception if query failed
self._inner_cursor = SnowflakeCursor(self.connection)
self._inner_cursor = self.__class__(self.connection)
self._sfqid = sfqid
self._prefetch_hook = wait_until_ready

Expand Down Expand Up @@ -1925,14 +1932,53 @@ def _create_file_transfer_agent(
)


class DictCursor(SnowflakeCursor):
class SnowflakeCursor(SnowflakeCursorBase[tuple[Any, ...]]):
"""Implementation of Cursor object that is returned from Connection.cursor() method.

Attributes:
description: A list of namedtuples about metadata for all columns.
rowcount: The number of records updated or selected. If not clear, -1 is returned.
rownumber: The current 0-based index of the cursor in the result set or None if the index cannot be
determined.
sfqid: Snowflake query id in UUID form. Include this in the problem report to the customer support.
sqlstate: Snowflake SQL State code.
timestamp_output_format: Snowflake timestamp_output_format for timestamps.
timestamp_ltz_output_format: Snowflake output format for LTZ timestamps.
timestamp_tz_output_format: Snowflake output format for TZ timestamps.
timestamp_ntz_output_format: Snowflake output format for NTZ timestamps.
date_output_format: Snowflake output format for dates.
time_output_format: Snowflake output format for times.
timezone: Snowflake timezone.
binary_output_format: Snowflake output format for binary fields.
arraysize: The default number of rows fetched by fetchmany.
connection: The connection object by which the cursor was created.
errorhandle: The class that handles error handling.
is_file_transfer: Whether, or not the current command is a put, or get.
"""

@property
def _use_dict_result(self) -> bool:
return False

def fetchone(self) -> tuple[Any, ...] | None:
row = self._fetchone()
if not (row is None or isinstance(row, tuple)):
raise TypeError(f"fetchone got unexpected result: {row}")
return row


class DictCursor(SnowflakeCursorBase[dict[str, Any]]):
"""Cursor returning results in a dictionary."""

def __init__(self, connection) -> None:
super().__init__(
connection,
use_dict_result=True,
)
@property
def _use_dict_result(self) -> bool:
return True

def fetchone(self) -> dict[str, Any] | None:
row = self._fetchone()
if not (row is None or isinstance(row, dict)):
raise TypeError(f"fetchone got unexpected result: {row}")
return row


def __getattr__(name):
Expand Down
10 changes: 7 additions & 3 deletions test/integ/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
QueryStatus = None


@pytest.mark.parametrize("cursor_class", [SnowflakeCursor, DictCursor])
def test_simple_async(conn_cnx, cursor_class):
@pytest.mark.parametrize(
"cursor_class, row_type", [(SnowflakeCursor, tuple), (DictCursor, dict)]
)
def test_simple_async(conn_cnx, cursor_class, row_type):
"""Simple test to that shows the most simple usage of fire and forget.

This test also makes sure that wait_until_ready function's sleeping is tested and
Expand All @@ -29,7 +31,9 @@ def test_simple_async(conn_cnx, cursor_class):
with con.cursor(cursor_class) as cur:
cur.execute_async("select count(*) from table(generator(timeLimit => 5))")
cur.get_results_from_sfqid(cur.sfqid)
assert len(cur.fetchall()) == 1
rows = cur.fetchall()
assert len(rows) == 1
assert isinstance(rows[0], row_type)
assert cur.rowcount
assert cur.description

Expand Down
2 changes: 1 addition & 1 deletion test/unit/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def mock_is_closed(*args, **kwargs):
cursor.execute("", _dataframe_ast="ABCD")


@patch("snowflake.connector.cursor.SnowflakeCursor._SnowflakeCursor__cancel_query")
@patch("snowflake.connector.cursor.SnowflakeCursor._SnowflakeCursorBase__cancel_query")
def test_cursor_execute_timeout(mockCancelQuery):
def mock_cmd_query(*args, **kwargs):
time.sleep(10)
Expand Down
Loading