Skip to content

Commit efb53de

Browse files
committed
feat(adapters): text factory and adapters
1 parent e65a892 commit efb53de

File tree

6 files changed

+358
-37
lines changed

6 files changed

+358
-37
lines changed

src/sqlitecloud/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
# the classes and functions from the dbapi2 module.
33
# eg: sqlite3.connect() -> sqlitecloud.connect()
44
#
5-
from .dbapi2 import Connection, Cursor, connect
5+
from .dbapi2 import Connection, Cursor, connect, register_adapter
66

7-
__all__ = ["VERSION", "Connection", "Cursor", "connect"]
7+
__all__ = ["VERSION", "Connection", "Cursor", "connect", "register_adapter"]
88

99
VERSION = "0.0.79"

src/sqlitecloud/dbapi2.py

Lines changed: 121 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# https://peps.python.org/pep-0249/
55
#
66
import logging
7+
from datetime import date, datetime
78
from typing import (
89
Any,
910
Callable,
@@ -13,6 +14,7 @@
1314
List,
1415
Optional,
1516
Tuple,
17+
Type,
1618
Union,
1719
overload,
1820
)
@@ -25,7 +27,14 @@
2527
SQLiteCloudException,
2628
)
2729
from sqlitecloud.driver import Driver
28-
from sqlitecloud.resultset import SQLITECLOUD_RESULT_TYPE, SQLiteCloudResult
30+
from sqlitecloud.resultset import (
31+
SQLITECLOUD_RESULT_TYPE,
32+
SQLITECLOUD_VALUE_TYPE,
33+
SQLiteCloudResult,
34+
)
35+
36+
# SQLite supported types
37+
SQLiteTypes = Union[int, float, str, bytes, None]
2938

3039
# Question mark style, e.g. ...WHERE name=?
3140
# Module also supports Named style, e.g. ...WHERE name=:name
@@ -37,6 +46,14 @@
3746
# DB API level
3847
apilevel = "2.0"
3948

49+
# These constants are meant to be used with the detect_types
50+
# parameter of the connect() function
51+
PARSE_DECLTYPES = 1
52+
PARSE_COLNAMES = 2
53+
54+
# Adapter registry to convert Python types to SQLite types
55+
adapters = {}
56+
4057

4158
@overload
4259
def connect(connection_str: str) -> "Connection":
@@ -80,6 +97,7 @@ def connect(
8097
def connect(
8198
connection_info: Union[str, SQLiteCloudAccount],
8299
config: Optional[SQLiteCloudConfig] = None,
100+
detect_types: int = 0,
83101
) -> "Connection":
84102
"""
85103
Establishes a connection to the SQLite Cloud database.
@@ -110,6 +128,21 @@ def connect(
110128
)
111129

112130

131+
def register_adapter(
132+
pytype: Type, adapter_callable: Callable[[object], SQLiteTypes]
133+
) -> None:
134+
"""
135+
Registers a callable to convert the type into one of the supported SQLite types.
136+
137+
Args:
138+
type (Type): The type to convert.
139+
callable (Callable): The callable that converts the type into a supported
140+
SQLite supported type.
141+
"""
142+
global adapters
143+
adapters[pytype] = adapter_callable
144+
145+
113146
class Connection:
114147
"""
115148
Represents a DB-APi 2.0 connection to the SQLite Cloud database.
@@ -123,11 +156,13 @@ class Connection:
123156
"""
124157

125158
row_factory: Optional[Callable[["Cursor", Tuple], object]] = None
159+
text_factory: Union[Type[Union[str, bytes]], Callable[[bytes], object]] = str
126160

127161
def __init__(self, sqlitecloud_connection: SQLiteCloudConnect) -> None:
128162
self._driver = Driver()
129163
self.row_factory = None
130164
self.sqlitecloud_connection = sqlitecloud_connection
165+
self.detect_types = 0
131166

132167
@property
133168
def sqlcloud_connection(self) -> SQLiteCloudConnect:
@@ -243,6 +278,21 @@ def cursor(self):
243278
cursor.row_factory = self.row_factory
244279
return cursor
245280

281+
def _apply_adapter(self, value: object) -> SQLiteTypes:
282+
"""
283+
Applies the adapter to convert the Python type into a SQLite supported type.
284+
285+
Args:
286+
value (object): The Python type to convert.
287+
288+
Returns:
289+
SQLiteTypes: The SQLite supported type.
290+
"""
291+
if type(value) in adapters:
292+
return adapters[type(value)](value)
293+
294+
return value
295+
246296
def __del__(self) -> None:
247297
self.close()
248298

@@ -364,6 +414,8 @@ def execute(
364414
"""
365415
self._ensure_connection()
366416

417+
parameters = self._adapt_parameters(parameters)
418+
367419
prepared_statement = self._driver.prepare_statement(sql, parameters)
368420
result = self._driver.execute(
369421
prepared_statement, self.connection.sqlcloud_connection
@@ -492,12 +544,37 @@ def _ensure_connection(self):
492544
if not self._connection:
493545
raise SQLiteCloudException("The cursor is closed.")
494546

547+
def _adapt_parameters(self, parameters: Union[Dict, Tuple]) -> Union[Dict, Tuple]:
548+
if isinstance(parameters, dict):
549+
params = {}
550+
for i in parameters.keys():
551+
params[i] = self._connection._apply_adapter(parameters[i])
552+
return params
553+
554+
return tuple(self._connection._apply_adapter(p) for p in parameters)
555+
556+
def _get_value(self, row: int, col: int) -> Optional[Any]:
557+
if not self._is_result_rowset():
558+
return None
559+
560+
# Convert TEXT type with text_factory
561+
decltype = self._resultset.get_decltype(col)
562+
if decltype is None or decltype == SQLITECLOUD_VALUE_TYPE.TEXT.value:
563+
value = self._resultset.get_value(row, col, False)
564+
565+
if self._connection.text_factory is bytes:
566+
return value.encode("utf-8")
567+
if self._connection.text_factory is str:
568+
return value
569+
# callable
570+
return self._connection.text_factory(value.encode("utf-8"))
571+
572+
return self._resultset.get_value(row, col)
573+
495574
def __iter__(self) -> "Cursor":
496575
return self
497576

498577
def __next__(self) -> Optional[Tuple[Any]]:
499-
self._ensure_connection()
500-
501578
if (
502579
not self._resultset.is_result
503580
and self._resultset.data
@@ -506,9 +583,49 @@ def __next__(self) -> Optional[Tuple[Any]]:
506583
out: Tuple[Any] = ()
507584

508585
for col in range(self._resultset.ncols):
509-
out += (self._resultset.get_value(self._iter_row, col),)
586+
out += (self._get_value(self._iter_row, col),)
510587
self._iter_row += 1
511588

512589
return self._call_row_factory(out)
513590

514591
raise StopIteration
592+
593+
594+
def register_adapters_and_converters():
595+
"""
596+
sqlite3 default adapters and converters.
597+
598+
This code is adapted from the Python standard library's sqlite3 module.
599+
The Python standard library is licensed under the Python Software Foundation License.
600+
Source: https://github.com/python/cpython/blob/3.6/Lib/sqlite3/dbapi2.py
601+
"""
602+
603+
def adapt_date(val):
604+
return val.isoformat()
605+
606+
def adapt_datetime(val):
607+
return val.isoformat(" ")
608+
609+
def convert_date(val):
610+
return datetime.date(*map(int, val.split(b"-")))
611+
612+
def convert_timestamp(val):
613+
datepart, timepart = val.split(b" ")
614+
year, month, day = map(int, datepart.split(b"-"))
615+
timepart_full = timepart.split(b".")
616+
hours, minutes, seconds = map(int, timepart_full[0].split(b":"))
617+
if len(timepart_full) == 2:
618+
microseconds = int("{:0<6.6}".format(timepart_full[1].decode()))
619+
else:
620+
microseconds = 0
621+
622+
val = datetime.datetime(year, month, day, hours, minutes, seconds, microseconds)
623+
return val
624+
625+
register_adapter(date, adapt_date)
626+
register_adapter(datetime, adapt_datetime)
627+
# register_converter("date", convert_date)
628+
# register_converter("timestamp", convert_timestamp)
629+
630+
631+
register_adapters_and_converters()

src/sqlitecloud/resultset.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,12 @@ def get_name(self, col: int) -> Optional[str]:
7373
return None
7474
return self.colname[col]
7575

76+
def get_decltype(self, col: int) -> Optional[str]:
77+
if col < 0 or col >= self.ncols or col >= len(self.decltype):
78+
return None
79+
80+
return self.decltype[col]
81+
7682
def _convert(self, value: str, col: int) -> any:
7783
if col < 0 or col >= len(self.decltype):
7884
return value

src/tests/conftest.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import sqlite3
23

34
import pytest
45
from dotenv import load_dotenv
@@ -59,3 +60,14 @@ def get_sqlitecloud_dbapi2_connection():
5960
yield connection
6061

6162
connection.close()
63+
64+
65+
def get_sqlite3_connection():
66+
# set isolation_level=None to enable autocommit
67+
# and to be aligned with the behavior of SQLite Cloud
68+
connection = sqlite3.connect(
69+
os.path.join(os.path.dirname(__file__), "./assets/chinook.sqlite"),
70+
isolation_level=None,
71+
)
72+
yield connection
73+
connection.close()

src/tests/integration/test_dbapi2.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -246,21 +246,3 @@ def test_row_factory(self, sqlitecloud_dbapi2_connection):
246246
assert row["AlbumId"] == 1
247247
assert row["Title"] == "For Those About To Rock We Salute You"
248248
assert row["ArtistId"] == 1
249-
250-
def test_commit_without_any_transaction_does_not_raise_exception(
251-
self, sqlitecloud_dbapi2_connection
252-
):
253-
connection = sqlitecloud_dbapi2_connection
254-
255-
connection.commit()
256-
257-
assert True
258-
259-
def test_rollback_without_any_transaction_does_not_raise_exception(
260-
self, sqlitecloud_dbapi2_connection
261-
):
262-
connection = sqlitecloud_dbapi2_connection
263-
264-
connection.rollback()
265-
266-
assert True

0 commit comments

Comments
 (0)