11#!/usr/bin/env python 
22from  __future__ import  annotations 
33
4+ import  abc 
45import  collections 
56import  logging 
67import  os 
1920    TYPE_CHECKING ,
2021    Any ,
2122    Callable ,
23+     Dict ,
24+     Generic ,
2225    Iterator ,
2326    Literal ,
2427    NamedTuple ,
2528    NoReturn ,
2629    Sequence ,
30+     Tuple ,
2731    TypeVar ,
32+     Union ,
2833    overload ,
2934)
3035
8691    from  .result_batch  import  ResultBatch 
8792
8893T  =  TypeVar ("T" , bound = collections .abc .Sequence )
94+ FetchRow  =  TypeVar ("FetchRow" , bound = Union [Tuple [Any , ...], Dict [str , Any ]])
8995
9096logger  =  getLogger (__name__ )
9197
@@ -332,29 +338,7 @@ class ResultState(Enum):
332338    RESET  =  3 
333339
334340
335- class  SnowflakeCursor :
336-     """Implementation of Cursor object that is returned from Connection.cursor() method. 
337- 
338-     Attributes: 
339-         description: A list of namedtuples about metadata for all columns. 
340-         rowcount: The number of records updated or selected. If not clear, -1 is returned. 
341-         rownumber: The current 0-based index of the cursor in the result set or None if the index cannot be 
342-             determined. 
343-         sfqid: Snowflake query id in UUID form. Include this in the problem report to the customer support. 
344-         sqlstate: Snowflake SQL State code. 
345-         timestamp_output_format: Snowflake timestamp_output_format for timestamps. 
346-         timestamp_ltz_output_format: Snowflake output format for LTZ timestamps. 
347-         timestamp_tz_output_format: Snowflake output format for TZ timestamps. 
348-         timestamp_ntz_output_format: Snowflake output format for NTZ timestamps. 
349-         date_output_format: Snowflake output format for dates. 
350-         time_output_format: Snowflake output format for times. 
351-         timezone: Snowflake timezone. 
352-         binary_output_format: Snowflake output format for binary fields. 
353-         arraysize: The default number of rows fetched by fetchmany. 
354-         connection: The connection object by which the cursor was created. 
355-         errorhandle: The class that handles error handling. 
356-         is_file_transfer: Whether, or not the current command is a put, or get. 
357-     """ 
341+ class  SnowflakeCursorBase (abc .ABC , Generic [FetchRow ]):
358342
359343    # TODO: 
360344    #    Most of these attributes have no reason to be properties, we could just store them in public variables. 
@@ -382,13 +366,11 @@ def get_file_transfer_type(sql: str) -> FileTransferType | None:
382366    def  __init__ (
383367        self ,
384368        connection : SnowflakeConnection ,
385-         use_dict_result : bool  =  False ,
386369    ) ->  None :
387370        """Inits a SnowflakeCursor with a connection. 
388371
389372        Args: 
390373            connection: The connection that created this cursor. 
391-             use_dict_result: Decides whether to use dict result or not. 
392374        """ 
393375        self ._connection : SnowflakeConnection  =  connection 
394376
@@ -423,7 +405,6 @@ def __init__(
423405        self ._result : Iterator [tuple ] |  Iterator [dict ] |  None  =  None 
424406        self ._result_set : ResultSet  |  None  =  None 
425407        self ._result_state : ResultState  =  ResultState .DEFAULT 
426-         self ._use_dict_result  =  use_dict_result 
427408        self .query : str  |  None  =  None 
428409        # TODO: self._query_result_format could be defined as an enum 
429410        self ._query_result_format : str  |  None  =  None 
@@ -435,7 +416,7 @@ def __init__(
435416        self ._first_chunk_time  =  None 
436417
437418        self ._log_max_query_length  =  connection .log_max_query_length 
438-         self ._inner_cursor : SnowflakeCursor  |  None  =  None 
419+         self ._inner_cursor : SnowflakeCursorBase  |  None  =  None 
439420        self ._prefetch_hook  =  None 
440421        self ._rownumber : int  |  None  =  None 
441422
@@ -448,6 +429,12 @@ def __del__(self) -> None:  # pragma: no cover
448429            if  logger .getEffectiveLevel () <=  logging .INFO :
449430                logger .info (e )
450431
432+     @property  
433+     @abc .abstractmethod  
434+     def  _use_dict_result (self ) ->  bool :
435+         """Decides whether results from helper functions are returned as a dict.""" 
436+         pass 
437+ 
451438    @property  
452439    def  description (self ) ->  list [ResultMetadata ]:
453440        if  self ._description  is  None :
@@ -1514,8 +1501,17 @@ def executemany(
15141501
15151502        return  self 
15161503
1517-     def  fetchone (self ) ->  dict  |  tuple  |  None :
1518-         """Fetches one row.""" 
1504+     @abc .abstractmethod  
1505+     def  fetchone (self ) ->  FetchRow :
1506+         pass 
1507+ 
1508+     def  _fetchone (self ) ->  dict [str , Any ] |  tuple [Any , ...] |  None :
1509+         """ 
1510+         Fetches one row. 
1511+ 
1512+         Returns a dict if self._use_dict_result is True, otherwise 
1513+         returns tuple. 
1514+         """ 
15191515        if  self ._prefetch_hook  is  not None :
15201516            self ._prefetch_hook ()
15211517        if  self ._result  is  None  and  self ._result_set  is  not None :
@@ -1539,7 +1535,7 @@ def fetchone(self) -> dict | tuple | None:
15391535            else :
15401536                return  None 
15411537
1542-     def  fetchmany (self , size : int  |  None  =  None ) ->  list [tuple ]  |   list [ dict ]:
1538+     def  fetchmany (self , size : int  |  None  =  None ) ->  list [FetchRow ]:
15431539        """Fetches the number of specified rows.""" 
15441540        if  size  is  None :
15451541            size  =  self .arraysize 
@@ -1565,7 +1561,7 @@ def fetchmany(self, size: int | None = None) -> list[tuple] | list[dict]:
15651561
15661562        return  ret 
15671563
1568-     def  fetchall (self ) ->  list [tuple ]  |   list [ dict ]:
1564+     def  fetchall (self ) ->  list [FetchRow ]:
15691565        """Fetches all of the results.""" 
15701566        ret  =  []
15711567        while  True :
@@ -1728,20 +1724,31 @@ def wait_until_ready() -> None:
17281724            # Unset this function, so that we don't block anymore 
17291725            self ._prefetch_hook  =  None 
17301726
1731-             if  (
1732-                 self ._inner_cursor ._total_rowcount  ==  1 
1733-                 and  self ._inner_cursor .fetchall ()
1734-                 ==  [("Multiple statements executed successfully." ,)]
1727+             if  self ._inner_cursor ._total_rowcount  ==  1  and  _is_successful_multi_stmt (
1728+                 self ._inner_cursor .fetchall ()
17351729            ):
17361730                url  =  f"/queries/{ sfqid }  
17371731                ret  =  self ._connection .rest .request (url = url , method = "get" )
17381732                if  "data"  in  ret  and  "resultIds"  in  ret ["data" ]:
17391733                    self ._init_multi_statement_results (ret ["data" ])
17401734
1735+         def  _is_successful_multi_stmt (rows : list [Any ]) ->  bool :
1736+             if  len (rows ) !=  1 :
1737+                 return  False 
1738+             row  =  rows [0 ]
1739+             if  isinstance (row , tuple ):
1740+                 return  row  ==  ("Multiple statements executed successfully." ,)
1741+             elif  isinstance (row , dict ):
1742+                 return  row  ==  {
1743+                     "multiple statement execution" : "Multiple statements executed successfully." 
1744+                 }
1745+             else :
1746+                 return  False 
1747+ 
17411748        self .connection .get_query_status_throw_if_error (
17421749            sfqid 
17431750        )  # Trigger an exception if query failed 
1744-         self ._inner_cursor  =  SnowflakeCursor (self .connection )
1751+         self ._inner_cursor  =  self . __class__ (self .connection )
17451752        self ._sfqid  =  sfqid 
17461753        self ._prefetch_hook  =  wait_until_ready 
17471754
@@ -1925,14 +1932,53 @@ def _create_file_transfer_agent(
19251932        )
19261933
19271934
1928- class  DictCursor (SnowflakeCursor ):
1935+ class  SnowflakeCursor (SnowflakeCursorBase [tuple [Any , ...]]):
1936+     """Implementation of Cursor object that is returned from Connection.cursor() method. 
1937+ 
1938+     Attributes: 
1939+         description: A list of namedtuples about metadata for all columns. 
1940+         rowcount: The number of records updated or selected. If not clear, -1 is returned. 
1941+         rownumber: The current 0-based index of the cursor in the result set or None if the index cannot be 
1942+             determined. 
1943+         sfqid: Snowflake query id in UUID form. Include this in the problem report to the customer support. 
1944+         sqlstate: Snowflake SQL State code. 
1945+         timestamp_output_format: Snowflake timestamp_output_format for timestamps. 
1946+         timestamp_ltz_output_format: Snowflake output format for LTZ timestamps. 
1947+         timestamp_tz_output_format: Snowflake output format for TZ timestamps. 
1948+         timestamp_ntz_output_format: Snowflake output format for NTZ timestamps. 
1949+         date_output_format: Snowflake output format for dates. 
1950+         time_output_format: Snowflake output format for times. 
1951+         timezone: Snowflake timezone. 
1952+         binary_output_format: Snowflake output format for binary fields. 
1953+         arraysize: The default number of rows fetched by fetchmany. 
1954+         connection: The connection object by which the cursor was created. 
1955+         errorhandle: The class that handles error handling. 
1956+         is_file_transfer: Whether, or not the current command is a put, or get. 
1957+     """ 
1958+ 
1959+     @property  
1960+     def  _use_dict_result (self ) ->  bool :
1961+         return  False 
1962+ 
1963+     def  fetchone (self ) ->  tuple [Any , ...] |  None :
1964+         row  =  self ._fetchone ()
1965+         if  not  (row  is  None  or  isinstance (row , tuple )):
1966+             raise  TypeError (f"fetchone got unexpected result: { row }  )
1967+         return  row 
1968+ 
1969+ 
1970+ class  DictCursor (SnowflakeCursorBase [dict [str , Any ]]):
19291971    """Cursor returning results in a dictionary.""" 
19301972
1931-     def  __init__ (self , connection ) ->  None :
1932-         super ().__init__ (
1933-             connection ,
1934-             use_dict_result = True ,
1935-         )
1973+     @property  
1974+     def  _use_dict_result (self ) ->  bool :
1975+         return  True 
1976+ 
1977+     def  fetchone (self ) ->  dict [str , Any ] |  None :
1978+         row  =  self ._fetchone ()
1979+         if  not  (row  is  None  or  isinstance (row , dict )):
1980+             raise  TypeError (f"fetchone got unexpected result: { row }  )
1981+         return  row 
19361982
19371983
19381984def  __getattr__ (name ):
0 commit comments