44# https://peps.python.org/pep-0249/
55#
66import logging
7+ from datetime import date , datetime
78from typing import (
89 Any ,
910 Callable ,
1314 List ,
1415 Optional ,
1516 Tuple ,
17+ Type ,
1618 Union ,
1719 overload ,
1820)
2527 SQLiteCloudException ,
2628)
2729from sqlitecloud .driver import Driver
28- from sqlitecloud .resultset import SQLITECLOUD_RESULT_TYPE , SQLiteCloudResult
30+ from sqlitecloud .resultset import (
31+ SQLITECLOUD_RESULT_TYPE ,
32+ SQLITECLOUD_VALUE_TYPE ,
33+ SQLiteCloudResult ,
34+ )
35+
36+ # SQLite supported types
37+ SQLiteTypes = Union [int , float , str , bytes , None ]
2938
3039# Question mark style, e.g. ...WHERE name=?
3140# Module also supports Named style, e.g. ...WHERE name=:name
3746# DB API level
3847apilevel = "2.0"
3948
49+ # These constants are meant to be used with the detect_types
50+ # parameter of the connect() function
51+ PARSE_DECLTYPES = 1
52+ PARSE_COLNAMES = 2
53+
54+ # Adapter registry to convert Python types to SQLite types
55+ adapters = {}
56+
4057
4158@overload
4259def connect (connection_str : str ) -> "Connection" :
@@ -80,6 +97,7 @@ def connect(
8097def connect (
8198 connection_info : Union [str , SQLiteCloudAccount ],
8299 config : Optional [SQLiteCloudConfig ] = None ,
100+ detect_types : int = 0 ,
83101) -> "Connection" :
84102 """
85103 Establishes a connection to the SQLite Cloud database.
@@ -110,6 +128,21 @@ def connect(
110128 )
111129
112130
131+ def register_adapter (
132+ pytype : Type , adapter_callable : Callable [[object ], SQLiteTypes ]
133+ ) -> None :
134+ """
135+ Registers a callable to convert the type into one of the supported SQLite types.
136+
137+ Args:
138+ type (Type): The type to convert.
139+ callable (Callable): The callable that converts the type into a supported
140+ SQLite supported type.
141+ """
142+ global adapters
143+ adapters [pytype ] = adapter_callable
144+
145+
113146class Connection :
114147 """
115148 Represents a DB-APi 2.0 connection to the SQLite Cloud database.
@@ -123,11 +156,13 @@ class Connection:
123156 """
124157
125158 row_factory : Optional [Callable [["Cursor" , Tuple ], object ]] = None
159+ text_factory : Union [Type [Union [str , bytes ]], Callable [[bytes ], object ]] = str
126160
127161 def __init__ (self , sqlitecloud_connection : SQLiteCloudConnect ) -> None :
128162 self ._driver = Driver ()
129163 self .row_factory = None
130164 self .sqlitecloud_connection = sqlitecloud_connection
165+ self .detect_types = 0
131166
132167 @property
133168 def sqlcloud_connection (self ) -> SQLiteCloudConnect :
@@ -243,6 +278,21 @@ def cursor(self):
243278 cursor .row_factory = self .row_factory
244279 return cursor
245280
281+ def _apply_adapter (self , value : object ) -> SQLiteTypes :
282+ """
283+ Applies the adapter to convert the Python type into a SQLite supported type.
284+
285+ Args:
286+ value (object): The Python type to convert.
287+
288+ Returns:
289+ SQLiteTypes: The SQLite supported type.
290+ """
291+ if type (value ) in adapters :
292+ return adapters [type (value )](value )
293+
294+ return value
295+
246296 def __del__ (self ) -> None :
247297 self .close ()
248298
@@ -364,6 +414,8 @@ def execute(
364414 """
365415 self ._ensure_connection ()
366416
417+ parameters = self ._adapt_parameters (parameters )
418+
367419 prepared_statement = self ._driver .prepare_statement (sql , parameters )
368420 result = self ._driver .execute (
369421 prepared_statement , self .connection .sqlcloud_connection
@@ -492,12 +544,37 @@ def _ensure_connection(self):
492544 if not self ._connection :
493545 raise SQLiteCloudException ("The cursor is closed." )
494546
547+ def _adapt_parameters (self , parameters : Union [Dict , Tuple ]) -> Union [Dict , Tuple ]:
548+ if isinstance (parameters , dict ):
549+ params = {}
550+ for i in parameters .keys ():
551+ params [i ] = self ._connection ._apply_adapter (parameters [i ])
552+ return params
553+
554+ return tuple (self ._connection ._apply_adapter (p ) for p in parameters )
555+
556+ def _get_value (self , row : int , col : int ) -> Optional [Any ]:
557+ if not self ._is_result_rowset ():
558+ return None
559+
560+ # Convert TEXT type with text_factory
561+ decltype = self ._resultset .get_decltype (col )
562+ if decltype is None or decltype == SQLITECLOUD_VALUE_TYPE .TEXT .value :
563+ value = self ._resultset .get_value (row , col , False )
564+
565+ if self ._connection .text_factory is bytes :
566+ return value .encode ("utf-8" )
567+ if self ._connection .text_factory is str :
568+ return value
569+ # callable
570+ return self ._connection .text_factory (value .encode ("utf-8" ))
571+
572+ return self ._resultset .get_value (row , col )
573+
495574 def __iter__ (self ) -> "Cursor" :
496575 return self
497576
498577 def __next__ (self ) -> Optional [Tuple [Any ]]:
499- self ._ensure_connection ()
500-
501578 if (
502579 not self ._resultset .is_result
503580 and self ._resultset .data
@@ -506,9 +583,49 @@ def __next__(self) -> Optional[Tuple[Any]]:
506583 out : Tuple [Any ] = ()
507584
508585 for col in range (self ._resultset .ncols ):
509- out += (self ._resultset . get_value (self ._iter_row , col ),)
586+ out += (self ._get_value (self ._iter_row , col ),)
510587 self ._iter_row += 1
511588
512589 return self ._call_row_factory (out )
513590
514591 raise StopIteration
592+
593+
594+ def register_adapters_and_converters ():
595+ """
596+ sqlite3 default adapters and converters.
597+
598+ This code is adapted from the Python standard library's sqlite3 module.
599+ The Python standard library is licensed under the Python Software Foundation License.
600+ Source: https://github.com/python/cpython/blob/3.6/Lib/sqlite3/dbapi2.py
601+ """
602+
603+ def adapt_date (val ):
604+ return val .isoformat ()
605+
606+ def adapt_datetime (val ):
607+ return val .isoformat (" " )
608+
609+ def convert_date (val ):
610+ return datetime .date (* map (int , val .split (b"-" )))
611+
612+ def convert_timestamp (val ):
613+ datepart , timepart = val .split (b" " )
614+ year , month , day = map (int , datepart .split (b"-" ))
615+ timepart_full = timepart .split (b"." )
616+ hours , minutes , seconds = map (int , timepart_full [0 ].split (b":" ))
617+ if len (timepart_full ) == 2 :
618+ microseconds = int ("{:0<6.6}" .format (timepart_full [1 ].decode ()))
619+ else :
620+ microseconds = 0
621+
622+ val = datetime .datetime (year , month , day , hours , minutes , seconds , microseconds )
623+ return val
624+
625+ register_adapter (date , adapt_date )
626+ register_adapter (datetime , adapt_datetime )
627+ # register_converter("date", convert_date)
628+ # register_converter("timestamp", convert_timestamp)
629+
630+
631+ register_adapters_and_converters ()
0 commit comments