Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions DESCRIPTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,19 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
- v3.18.0(TBD)
- Added the `workload_identity_impersonation_path` parameter to support service account impersonation for Workload Identity Federation on GCP and AWS workloads only
- Fixed `get_results_from_sfqid` when using `DictCursor` and executing multiple statements at once
- Added the `oauth_credentials_in_body` parameter supporting an option to send the oauth client credentials in the request body
- Fix retry behavior for `ECONNRESET` error
- Added an option to exclude `botocore` and `boto3` dependencies by setting `SNOWFLAKE_NO_BOTO` environment variable during installation
- Revert changing exception type in case of token expired scenario for `Oauth` authenticator back to `DatabaseError`
- Added support for pandas conversion for Day-time and Year-Month Interval types

- v3.17.4(September 22,2025)
- Added support for intermediate certificates as roots when they are stored in the trust store
- Bumped up vendored `urllib3` to `2.5.0` and `requests` to `v2.32.5`
- Dropped support for OpenSSL versions older than 1.1.1
- Fixed the return type of `SnowflakeConnection.cursor(cursor_class)` to match the type of `cursor_class`
- Constrained the types of `fetchone`, `fetchmany`, `fetchall`
- As part of this fix, `DictCursor` is no longer a subclass of `SnowflakeCursor`; use `SnowflakeCursorBase` as a superclass of both.

- v3.17.3(September 02,2025)
- Enhanced configuration file permission warning messages.
Expand Down
14 changes: 9 additions & 5 deletions src/snowflake/connector/aio/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
import os
import pathlib
import sys
import typing
import uuid
import warnings
from contextlib import suppress
from io import StringIO
from logging import getLogger
from types import TracebackType
from typing import Any, AsyncIterator, Iterable
from typing import Any, AsyncIterator, Iterable, TypeVar

from snowflake.connector import (
DatabaseError,
Expand Down Expand Up @@ -73,7 +74,7 @@
from ..time_util import get_time_millis
from ..util_text import split_statements
from ..wif_util import AttestationProvider
from ._cursor import SnowflakeCursor
from ._cursor import SnowflakeCursor, SnowflakeCursorBase
from ._description import CLIENT_NAME
from ._direct_file_operation_utils import FileOperationParser, StreamDownloader
from ._network import SnowflakeRestful
Expand Down Expand Up @@ -107,6 +108,11 @@
DEFAULT_CONFIGURATION = copy.deepcopy(DEFAULT_CONFIGURATION_SYNC)
DEFAULT_CONFIGURATION["application"] = (CLIENT_NAME, (type(None), str))

if sys.version_info >= (3, 13) or typing.TYPE_CHECKING:
CursorCls = TypeVar("CursorCls", bound=SnowflakeCursorBase, default=SnowflakeCursor)
else:
CursorCls = TypeVar("CursorCls", bound=SnowflakeCursorBase)


class SnowflakeConnection(SnowflakeConnectionSync):
OCSP_ENV_LOCK = asyncio.Lock()
Expand Down Expand Up @@ -1031,9 +1037,7 @@ async def connect(self, **kwargs) -> None:
self._telemetry = TelemetryClient(self._rest)
await self._log_telemetry_imported_packages()

def cursor(
self, cursor_class: type[SnowflakeCursor] = SnowflakeCursor
) -> SnowflakeCursor:
def cursor(self, cursor_class: type[CursorCls] = SnowflakeCursor) -> CursorCls:
logger.debug("cursor")
if not self.rest:
Error.errorhandler_wrapper(
Expand Down
104 changes: 87 additions & 17 deletions src/snowflake/connector/aio/_cursor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import abc
import asyncio
import collections
import logging
Expand Down Expand Up @@ -39,10 +40,11 @@
ASYNC_NO_DATA_MAX_RETRY,
ASYNC_RETRY_PATTERN,
DESC_TABLE_RE,
ResultMetadata,
ResultMetadataV2,
ResultState,
)
from snowflake.connector.cursor import DictCursor as DictCursorSync
from snowflake.connector.cursor import ResultMetadata, ResultMetadataV2, ResultState
from snowflake.connector.cursor import SnowflakeCursor as SnowflakeCursorSync
from snowflake.connector.cursor import SnowflakeCursorBase as SnowflakeCursorBaseSync
from snowflake.connector.cursor import T
from snowflake.connector.errorcode import (
ER_CURSOR_IS_CLOSED,
Expand All @@ -66,17 +68,20 @@

logger = getLogger(__name__)

FetchRow = typing.TypeVar(
"FetchRow", bound=typing.Union[typing.Tuple[Any, ...], typing.Dict[str, Any]]
)


class SnowflakeCursor(SnowflakeCursorSync):
class SnowflakeCursorBase(SnowflakeCursorBaseSync, abc.ABC, typing.Generic[FetchRow]):
def __init__(
self,
connection: SnowflakeConnection,
use_dict_result: bool = False,
):
super().__init__(connection, use_dict_result)
super().__init__(connection)
# the following fixes type hint
self._connection = typing.cast("SnowflakeConnection", self._connection)
self._inner_cursor: SnowflakeCursor | None = None
self._inner_cursor: SnowflakeCursorBase | None = None
self._lock_canceling = asyncio.Lock()
self._timebomb: asyncio.Task | None = None
self._prefetch_hook: typing.Callable[[], typing.Awaitable] | None = None
Expand Down Expand Up @@ -900,8 +905,17 @@ async def describe(self, *args: Any, **kwargs: Any) -> list[ResultMetadata]:
return None
return [meta._to_result_metadata_v1() for meta in self._description]

async def fetchone(self) -> dict | tuple | None:
"""Fetches one row."""
@abc.abstractmethod
async def fetchone(self) -> FetchRow:
pass

async def _fetchone(self) -> dict[str, Any] | tuple[Any, ...] | None:
"""
Fetches one row.

Returns a dict if self._use_dict_result is True, otherwise
returns tuple.
"""
if self._prefetch_hook is not None:
await self._prefetch_hook()
if self._result is None and self._result_set is not None:
Expand All @@ -926,7 +940,7 @@ async def fetchone(self) -> dict | tuple | None:
else:
return None

async def fetchmany(self, size: int | None = None) -> list[tuple] | list[dict]:
async def fetchmany(self, size: int | None = None) -> list[FetchRow]:
"""Fetches the number of specified rows."""
if size is None:
size = self.arraysize
Expand Down Expand Up @@ -1266,20 +1280,31 @@ async def wait_until_ready() -> None:
# Unset this function, so that we don't block anymore
self._prefetch_hook = None

if (
self._inner_cursor._total_rowcount == 1
and await self._inner_cursor.fetchall()
== [("Multiple statements executed successfully.",)]
if self._inner_cursor._total_rowcount == 1 and _is_successful_multi_stmt(
await self._inner_cursor.fetchall()
):
url = f"/queries/{sfqid}/result"
ret = await self._connection.rest.request(url=url, method="get")
if "data" in ret and "resultIds" in ret["data"]:
await self._init_multi_statement_results(ret["data"])

def _is_successful_multi_stmt(rows: list[Any]) -> bool:
if len(rows) != 1:
return False
row = rows[0]
if isinstance(row, tuple):
return row == ("Multiple statements executed successfully.",)
elif isinstance(row, dict):
return row == {
"multiple statement execution": "Multiple statements executed successfully."
}
else:
return False

await self.connection.get_query_status_throw_if_error(
sfqid
) # Trigger an exception if query failed
self._inner_cursor = SnowflakeCursor(self.connection)
self._inner_cursor = self.__class__(self.connection)
self._sfqid = sfqid
self._prefetch_hook = wait_until_ready

Expand Down Expand Up @@ -1321,5 +1346,50 @@ async def query_result(self, qid: str) -> SnowflakeCursor:
return self


class DictCursor(DictCursorSync, SnowflakeCursor):
pass
class SnowflakeCursor(SnowflakeCursorBase[tuple[Any, ...]]):
"""Implementation of Cursor object that is returned from Connection.cursor() method.

Attributes:
description: A list of namedtuples about metadata for all columns.
rowcount: The number of records updated or selected. If not clear, -1 is returned.
rownumber: The current 0-based index of the cursor in the result set or None if the index cannot be
determined.
sfqid: Snowflake query id in UUID form. Include this in the problem report to the customer support.
sqlstate: Snowflake SQL State code.
timestamp_output_format: Snowflake timestamp_output_format for timestamps.
timestamp_ltz_output_format: Snowflake output format for LTZ timestamps.
timestamp_tz_output_format: Snowflake output format for TZ timestamps.
timestamp_ntz_output_format: Snowflake output format for NTZ timestamps.
date_output_format: Snowflake output format for dates.
time_output_format: Snowflake output format for times.
timezone: Snowflake timezone.
binary_output_format: Snowflake output format for binary fields.
arraysize: The default number of rows fetched by fetchmany.
connection: The connection object by which the cursor was created.
errorhandle: The class that handles error handling.
is_file_transfer: Whether, or not the current command is a put, or get.
"""

@property
def _use_dict_result(self) -> bool:
return False

async def fetchone(self) -> tuple[Any, ...] | None:
row = await self._fetchone()
if not (row is None or isinstance(row, tuple)):
raise TypeError(f"fetchone got unexpected result: {row}")
return row


class DictCursor(SnowflakeCursorBase[dict[str, Any]]):
"""Cursor returning results in a dictionary."""

@property
def _use_dict_result(self) -> bool:
return True

def fetchone(self) -> dict[str, Any] | None:
row = self._fetchone()
if not (row is None or isinstance(row, dict)):
raise TypeError(f"fetchone got unexpected result: {row}")
return row
8 changes: 7 additions & 1 deletion src/snowflake/connector/aio/auth/_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ async def authenticate(
# max time waiting for MFA response, currently unused
timeout: int | None = None,
) -> dict[str, str | int | bool]:
from . import AuthByOAuth

if mfa_callback or password_callback:
# TODO: SNOW-1707210 for mfa_callback and password_callback support
raise NotImplementedError(
Expand Down Expand Up @@ -285,7 +287,11 @@ async def post_request_wrapper(self, url, headers, body) -> None:
sqlstate=SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED,
)
)
elif errno == OAUTH_ACCESS_TOKEN_EXPIRED_GS_CODE:
elif (errno == OAUTH_ACCESS_TOKEN_EXPIRED_GS_CODE) and (
# SNOW-2329031: OAuth v1.0 does not support token renewal,
# for backward compatibility, we do not raise an exception here
not isinstance(auth_instance, AuthByOAuth)
):
raise ReauthenticationRequest(
ProgrammingError(
msg=ret["message"],
Expand Down
23 changes: 18 additions & 5 deletions src/snowflake/connector/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import re
import sys
import traceback
import typing
import uuid
import warnings
import weakref
Expand All @@ -20,7 +21,16 @@
from logging import getLogger
from threading import Lock
from types import TracebackType
from typing import Any, Callable, Generator, Iterable, Iterator, NamedTuple, Sequence
from typing import (
Any,
Callable,
Generator,
Iterable,
Iterator,
NamedTuple,
Sequence,
TypeVar,
)
from uuid import UUID

from cryptography.hazmat.backends import default_backend
Expand Down Expand Up @@ -76,7 +86,7 @@
QueryStatus,
)
from .converter import SnowflakeConverter
from .cursor import LOG_MAX_QUERY_LENGTH, SnowflakeCursor
from .cursor import LOG_MAX_QUERY_LENGTH, SnowflakeCursor, SnowflakeCursorBase
from .description import (
CLIENT_NAME,
CLIENT_VERSION,
Expand Down Expand Up @@ -125,6 +135,11 @@
from .util_text import construct_hostname, parse_account, split_statements
from .wif_util import AttestationProvider

if sys.version_info >= (3, 13) or typing.TYPE_CHECKING:
CursorCls = TypeVar("CursorCls", bound=SnowflakeCursorBase, default=SnowflakeCursor)
else:
CursorCls = TypeVar("CursorCls", bound=SnowflakeCursorBase)

DEFAULT_CLIENT_PREFETCH_THREADS = 4
MAX_CLIENT_PREFETCH_THREADS = 10
MAX_CLIENT_FETCH_THREADS = 1024
Expand Down Expand Up @@ -1050,9 +1065,7 @@ def rollback(self) -> None:
"""Rolls back the current transaction."""
self.cursor().execute("ROLLBACK")

def cursor(
self, cursor_class: type[SnowflakeCursor] = SnowflakeCursor
) -> SnowflakeCursor:
def cursor(self, cursor_class: type[CursorCls] = SnowflakeCursor) -> CursorCls:
"""Creates a cursor object. Each statement will be executed in a new cursor object."""
logger.debug("cursor")
if not self.rest:
Expand Down
Loading
Loading