Skip to content

Commit 3e1805e

Browse files
committed
SNOW-2333702 Fix return type of SnowflakeConnection.cursor()
1 parent 1028e4f commit 3e1805e

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

DESCRIPTION.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
1010
- v3.18.0(TBD)
1111
- Added the `workload_identity_impersonation_path` parameter to support service account impersonation for Workload Identity Federation on GCP and AWS workloads only
1212
- Fixed `get_results_from_sfqid` when using `DictCursor` and executing multiple statements at once
13+
- Fixed the return type of `SnowflakeConnection.cursor(cursor_class)` to match the type of `cursor_class`
1314
- Constrained the types of `fetchone`, `fetchmany`, `fetchall`
1415
- As part of this fix, `DictCursor` is no longer a subclass of `SnowflakeCursor`; use `SnowflakeCursorBase` as a superclass of both.
1516

src/snowflake/connector/connection.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import re
99
import sys
1010
import traceback
11+
import typing
1112
import uuid
1213
import warnings
1314
import weakref
@@ -20,7 +21,16 @@
2021
from logging import getLogger
2122
from threading import Lock
2223
from types import TracebackType
23-
from typing import Any, Callable, Generator, Iterable, Iterator, NamedTuple, Sequence
24+
from typing import (
25+
Any,
26+
Callable,
27+
Generator,
28+
Iterable,
29+
Iterator,
30+
NamedTuple,
31+
Sequence,
32+
TypeVar,
33+
)
2434
from uuid import UUID
2535

2636
from cryptography.hazmat.backends import default_backend
@@ -76,7 +86,7 @@
7686
QueryStatus,
7787
)
7888
from .converter import SnowflakeConverter
79-
from .cursor import LOG_MAX_QUERY_LENGTH, SnowflakeCursor
89+
from .cursor import LOG_MAX_QUERY_LENGTH, SnowflakeCursor, SnowflakeCursorBase
8090
from .description import (
8191
CLIENT_NAME,
8292
CLIENT_VERSION,
@@ -125,6 +135,11 @@
125135
from .util_text import construct_hostname, parse_account, split_statements
126136
from .wif_util import AttestationProvider
127137

138+
if sys.version_info >= (3, 13) or typing.TYPE_CHECKING:
139+
CursorCls = TypeVar("CursorCls", bound=SnowflakeCursorBase, default=SnowflakeCursor)
140+
else:
141+
CursorCls = TypeVar("CursorCls", bound=SnowflakeCursorBase)
142+
128143
DEFAULT_CLIENT_PREFETCH_THREADS = 4
129144
MAX_CLIENT_PREFETCH_THREADS = 10
130145
MAX_CLIENT_FETCH_THREADS = 1024
@@ -1055,9 +1070,7 @@ def rollback(self) -> None:
10551070
"""Rolls back the current transaction."""
10561071
self.cursor().execute("ROLLBACK")
10571072

1058-
def cursor(
1059-
self, cursor_class: type[SnowflakeCursor] = SnowflakeCursor
1060-
) -> SnowflakeCursor:
1073+
def cursor(self, cursor_class: type[CursorCls] = SnowflakeCursor) -> CursorCls:
10611074
"""Creates a cursor object. Each statement will be executed in a new cursor object."""
10621075
logger.debug("cursor")
10631076
if not self.rest:

0 commit comments

Comments
 (0)