Skip to content

Commit 842692b

Browse files
sfc-gh-bchinnsfc-gh-pczajka
authored andcommitted
SNOW-2333702 Fix types for DictCursor (#2532)
Co-authored-by: Patryk Czajka <[email protected]>
1 parent 012eed6 commit 842692b

File tree

9 files changed

+233
-75
lines changed

9 files changed

+233
-75
lines changed

DESCRIPTION.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,19 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
1010
- v3.18.0(TBD)
1111
- Added the `workload_identity_impersonation_path` parameter to support service account impersonation for Workload Identity Federation on GCP and AWS workloads only
1212
- Fixed `get_results_from_sfqid` when using `DictCursor` and executing multiple statements at once
13+
- Added the `oauth_credentials_in_body` parameter supporting an option to send the oauth client credentials in the request body
14+
- Fix retry behavior for `ECONNRESET` error
15+
- Added an option to exclude `botocore` and `boto3` dependencies by setting `SNOWFLAKE_NO_BOTO` environment variable during installation
16+
- Revert changing exception type in case of token expired scenario for `Oauth` authenticator back to `DatabaseError`
17+
- Added support for pandas conversion for Day-time and Year-Month Interval types
18+
19+
- v3.17.4(September 22,2025)
20+
- Added support for intermediate certificates as roots when they are stored in the trust store
21+
- Bumped up vendored `urllib3` to `2.5.0` and `requests` to `v2.32.5`
22+
- Dropped support for OpenSSL versions older than 1.1.1
23+
- Fixed the return type of `SnowflakeConnection.cursor(cursor_class)` to match the type of `cursor_class`
24+
- Constrained the types of `fetchone`, `fetchmany`, `fetchall`
25+
- As part of this fix, `DictCursor` is no longer a subclass of `SnowflakeCursor`; use `SnowflakeCursorBase` as a superclass of both.
1326

1427
- v3.17.3(September 02,2025)
1528
- Enhanced configuration file permission warning messages.

src/snowflake/connector/aio/_connection.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@
77
import os
88
import pathlib
99
import sys
10+
import typing
1011
import uuid
1112
import warnings
1213
from contextlib import suppress
1314
from io import StringIO
1415
from logging import getLogger
1516
from types import TracebackType
16-
from typing import Any, AsyncIterator, Iterable
17+
from typing import Any, AsyncIterator, Iterable, TypeVar
1718

1819
from snowflake.connector import (
1920
DatabaseError,
@@ -73,7 +74,7 @@
7374
from ..time_util import get_time_millis
7475
from ..util_text import split_statements
7576
from ..wif_util import AttestationProvider
76-
from ._cursor import SnowflakeCursor
77+
from ._cursor import SnowflakeCursor, SnowflakeCursorBase
7778
from ._description import CLIENT_NAME
7879
from ._direct_file_operation_utils import FileOperationParser, StreamDownloader
7980
from ._network import SnowflakeRestful
@@ -107,6 +108,11 @@
107108
DEFAULT_CONFIGURATION = copy.deepcopy(DEFAULT_CONFIGURATION_SYNC)
108109
DEFAULT_CONFIGURATION["application"] = (CLIENT_NAME, (type(None), str))
109110

111+
if sys.version_info >= (3, 13) or typing.TYPE_CHECKING:
112+
CursorCls = TypeVar("CursorCls", bound=SnowflakeCursorBase, default=SnowflakeCursor)
113+
else:
114+
CursorCls = TypeVar("CursorCls", bound=SnowflakeCursorBase)
115+
110116

111117
class SnowflakeConnection(SnowflakeConnectionSync):
112118
OCSP_ENV_LOCK = asyncio.Lock()
@@ -1031,9 +1037,7 @@ async def connect(self, **kwargs) -> None:
10311037
self._telemetry = TelemetryClient(self._rest)
10321038
await self._log_telemetry_imported_packages()
10331039

1034-
def cursor(
1035-
self, cursor_class: type[SnowflakeCursor] = SnowflakeCursor
1036-
) -> SnowflakeCursor:
1040+
def cursor(self, cursor_class: type[CursorCls] = SnowflakeCursor) -> CursorCls:
10371041
logger.debug("cursor")
10381042
if not self.rest:
10391043
Error.errorhandler_wrapper(

src/snowflake/connector/aio/_cursor.py

Lines changed: 87 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import abc
34
import asyncio
45
import collections
56
import logging
@@ -39,10 +40,11 @@
3940
ASYNC_NO_DATA_MAX_RETRY,
4041
ASYNC_RETRY_PATTERN,
4142
DESC_TABLE_RE,
43+
ResultMetadata,
44+
ResultMetadataV2,
45+
ResultState,
4246
)
43-
from snowflake.connector.cursor import DictCursor as DictCursorSync
44-
from snowflake.connector.cursor import ResultMetadata, ResultMetadataV2, ResultState
45-
from snowflake.connector.cursor import SnowflakeCursor as SnowflakeCursorSync
47+
from snowflake.connector.cursor import SnowflakeCursorBase as SnowflakeCursorBaseSync
4648
from snowflake.connector.cursor import T
4749
from snowflake.connector.errorcode import (
4850
ER_CURSOR_IS_CLOSED,
@@ -66,17 +68,20 @@
6668

6769
logger = getLogger(__name__)
6870

71+
FetchRow = typing.TypeVar(
72+
"FetchRow", bound=typing.Union[typing.Tuple[Any, ...], typing.Dict[str, Any]]
73+
)
74+
6975

70-
class SnowflakeCursor(SnowflakeCursorSync):
76+
class SnowflakeCursorBase(SnowflakeCursorBaseSync, abc.ABC, typing.Generic[FetchRow]):
7177
def __init__(
7278
self,
7379
connection: SnowflakeConnection,
74-
use_dict_result: bool = False,
7580
):
76-
super().__init__(connection, use_dict_result)
81+
super().__init__(connection)
7782
# the following fixes type hint
7883
self._connection = typing.cast("SnowflakeConnection", self._connection)
79-
self._inner_cursor: SnowflakeCursor | None = None
84+
self._inner_cursor: SnowflakeCursorBase | None = None
8085
self._lock_canceling = asyncio.Lock()
8186
self._timebomb: asyncio.Task | None = None
8287
self._prefetch_hook: typing.Callable[[], typing.Awaitable] | None = None
@@ -900,8 +905,17 @@ async def describe(self, *args: Any, **kwargs: Any) -> list[ResultMetadata]:
900905
return None
901906
return [meta._to_result_metadata_v1() for meta in self._description]
902907

903-
async def fetchone(self) -> dict | tuple | None:
904-
"""Fetches one row."""
908+
@abc.abstractmethod
909+
async def fetchone(self) -> FetchRow:
910+
pass
911+
912+
async def _fetchone(self) -> dict[str, Any] | tuple[Any, ...] | None:
913+
"""
914+
Fetches one row.
915+
916+
Returns a dict if self._use_dict_result is True, otherwise
917+
returns tuple.
918+
"""
905919
if self._prefetch_hook is not None:
906920
await self._prefetch_hook()
907921
if self._result is None and self._result_set is not None:
@@ -926,7 +940,7 @@ async def fetchone(self) -> dict | tuple | None:
926940
else:
927941
return None
928942

929-
async def fetchmany(self, size: int | None = None) -> list[tuple] | list[dict]:
943+
async def fetchmany(self, size: int | None = None) -> list[FetchRow]:
930944
"""Fetches the number of specified rows."""
931945
if size is None:
932946
size = self.arraysize
@@ -1266,20 +1280,31 @@ async def wait_until_ready() -> None:
12661280
# Unset this function, so that we don't block anymore
12671281
self._prefetch_hook = None
12681282

1269-
if (
1270-
self._inner_cursor._total_rowcount == 1
1271-
and await self._inner_cursor.fetchall()
1272-
== [("Multiple statements executed successfully.",)]
1283+
if self._inner_cursor._total_rowcount == 1 and _is_successful_multi_stmt(
1284+
await self._inner_cursor.fetchall()
12731285
):
12741286
url = f"/queries/{sfqid}/result"
12751287
ret = await self._connection.rest.request(url=url, method="get")
12761288
if "data" in ret and "resultIds" in ret["data"]:
12771289
await self._init_multi_statement_results(ret["data"])
12781290

1291+
def _is_successful_multi_stmt(rows: list[Any]) -> bool:
1292+
if len(rows) != 1:
1293+
return False
1294+
row = rows[0]
1295+
if isinstance(row, tuple):
1296+
return row == ("Multiple statements executed successfully.",)
1297+
elif isinstance(row, dict):
1298+
return row == {
1299+
"multiple statement execution": "Multiple statements executed successfully."
1300+
}
1301+
else:
1302+
return False
1303+
12791304
await self.connection.get_query_status_throw_if_error(
12801305
sfqid
12811306
) # Trigger an exception if query failed
1282-
self._inner_cursor = SnowflakeCursor(self.connection)
1307+
self._inner_cursor = self.__class__(self.connection)
12831308
self._sfqid = sfqid
12841309
self._prefetch_hook = wait_until_ready
12851310

@@ -1321,5 +1346,50 @@ async def query_result(self, qid: str) -> SnowflakeCursor:
13211346
return self
13221347

13231348

1324-
class DictCursor(DictCursorSync, SnowflakeCursor):
1325-
pass
1349+
class SnowflakeCursor(SnowflakeCursorBase[tuple[Any, ...]]):
1350+
"""Implementation of Cursor object that is returned from Connection.cursor() method.
1351+
1352+
Attributes:
1353+
description: A list of namedtuples about metadata for all columns.
1354+
rowcount: The number of records updated or selected. If not clear, -1 is returned.
1355+
rownumber: The current 0-based index of the cursor in the result set or None if the index cannot be
1356+
determined.
1357+
sfqid: Snowflake query id in UUID form. Include this in the problem report to the customer support.
1358+
sqlstate: Snowflake SQL State code.
1359+
timestamp_output_format: Snowflake timestamp_output_format for timestamps.
1360+
timestamp_ltz_output_format: Snowflake output format for LTZ timestamps.
1361+
timestamp_tz_output_format: Snowflake output format for TZ timestamps.
1362+
timestamp_ntz_output_format: Snowflake output format for NTZ timestamps.
1363+
date_output_format: Snowflake output format for dates.
1364+
time_output_format: Snowflake output format for times.
1365+
timezone: Snowflake timezone.
1366+
binary_output_format: Snowflake output format for binary fields.
1367+
arraysize: The default number of rows fetched by fetchmany.
1368+
connection: The connection object by which the cursor was created.
1369+
errorhandle: The class that handles error handling.
1370+
is_file_transfer: Whether, or not the current command is a put, or get.
1371+
"""
1372+
1373+
@property
1374+
def _use_dict_result(self) -> bool:
1375+
return False
1376+
1377+
async def fetchone(self) -> tuple[Any, ...] | None:
1378+
row = await self._fetchone()
1379+
if not (row is None or isinstance(row, tuple)):
1380+
raise TypeError(f"fetchone got unexpected result: {row}")
1381+
return row
1382+
1383+
1384+
class DictCursor(SnowflakeCursorBase[dict[str, Any]]):
1385+
"""Cursor returning results in a dictionary."""
1386+
1387+
@property
1388+
def _use_dict_result(self) -> bool:
1389+
return True
1390+
1391+
def fetchone(self) -> dict[str, Any] | None:
1392+
row = self._fetchone()
1393+
if not (row is None or isinstance(row, dict)):
1394+
raise TypeError(f"fetchone got unexpected result: {row}")
1395+
return row

src/snowflake/connector/aio/auth/_auth.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ async def authenticate(
6262
# max time waiting for MFA response, currently unused
6363
timeout: int | None = None,
6464
) -> dict[str, str | int | bool]:
65+
from . import AuthByOAuth
66+
6567
if mfa_callback or password_callback:
6668
# TODO: SNOW-1707210 for mfa_callback and password_callback support
6769
raise NotImplementedError(
@@ -285,7 +287,11 @@ async def post_request_wrapper(self, url, headers, body) -> None:
285287
sqlstate=SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED,
286288
)
287289
)
288-
elif errno == OAUTH_ACCESS_TOKEN_EXPIRED_GS_CODE:
290+
elif (errno == OAUTH_ACCESS_TOKEN_EXPIRED_GS_CODE) and (
291+
# SNOW-2329031: OAuth v1.0 does not support token renewal,
292+
# for backward compatibility, we do not raise an exception here
293+
not isinstance(auth_instance, AuthByOAuth)
294+
):
289295
raise ReauthenticationRequest(
290296
ProgrammingError(
291297
msg=ret["message"],

src/snowflake/connector/connection.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import re
99
import sys
1010
import traceback
11+
import typing
1112
import uuid
1213
import warnings
1314
import weakref
@@ -20,7 +21,16 @@
2021
from logging import getLogger
2122
from threading import Lock
2223
from types import TracebackType
23-
from typing import Any, Callable, Generator, Iterable, Iterator, NamedTuple, Sequence
24+
from typing import (
25+
Any,
26+
Callable,
27+
Generator,
28+
Iterable,
29+
Iterator,
30+
NamedTuple,
31+
Sequence,
32+
TypeVar,
33+
)
2434
from uuid import UUID
2535

2636
from cryptography.hazmat.backends import default_backend
@@ -76,7 +86,7 @@
7686
QueryStatus,
7787
)
7888
from .converter import SnowflakeConverter
79-
from .cursor import LOG_MAX_QUERY_LENGTH, SnowflakeCursor
89+
from .cursor import LOG_MAX_QUERY_LENGTH, SnowflakeCursor, SnowflakeCursorBase
8090
from .description import (
8191
CLIENT_NAME,
8292
CLIENT_VERSION,
@@ -125,6 +135,11 @@
125135
from .util_text import construct_hostname, parse_account, split_statements
126136
from .wif_util import AttestationProvider
127137

138+
if sys.version_info >= (3, 13) or typing.TYPE_CHECKING:
139+
CursorCls = TypeVar("CursorCls", bound=SnowflakeCursorBase, default=SnowflakeCursor)
140+
else:
141+
CursorCls = TypeVar("CursorCls", bound=SnowflakeCursorBase)
142+
128143
DEFAULT_CLIENT_PREFETCH_THREADS = 4
129144
MAX_CLIENT_PREFETCH_THREADS = 10
130145
MAX_CLIENT_FETCH_THREADS = 1024
@@ -1050,9 +1065,7 @@ def rollback(self) -> None:
10501065
"""Rolls back the current transaction."""
10511066
self.cursor().execute("ROLLBACK")
10521067

1053-
def cursor(
1054-
self, cursor_class: type[SnowflakeCursor] = SnowflakeCursor
1055-
) -> SnowflakeCursor:
1068+
def cursor(self, cursor_class: type[CursorCls] = SnowflakeCursor) -> CursorCls:
10561069
"""Creates a cursor object. Each statement will be executed in a new cursor object."""
10571070
logger.debug("cursor")
10581071
if not self.rest:

0 commit comments

Comments
 (0)