Skip to content

Commit 2ad05fd

Browse files
committed
fix(types): respect types from SCSP
Rowset values are sent by the SCSP protocol with their sqlite3 detected type: values now respect this data type instead of the decltype Fixed several tests
1 parent efb53de commit 2ad05fd

File tree

7 files changed

+101
-56
lines changed

7 files changed

+101
-56
lines changed

src/sqlitecloud/datatypes.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55

66
from .resultset import SQLiteCloudResultSet
77

8+
# SQLite supported data types
9+
SQLiteDataTypes = Union[str, int, float, bytes, None]
10+
11+
812
# Basic types supported by SQLite Cloud APIs
913
SQLiteCloudDataTypes = Union[str, int, bool, Dict[Union[str, int], Any], bytes, None]
1014

@@ -249,6 +253,6 @@ class SQLiteCloudValue:
249253
"""
250254

251255
def __init__(self) -> None:
252-
self.value: Optional[str] = None
256+
self.value: Optional[SQLiteCloudDataTypes] = None
253257
self.len: int = 0
254258
self.cellsize: int = 0

src/sqlitecloud/dbapi2.py

Lines changed: 20 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
SQLiteCloudAccount,
2424
SQLiteCloudConfig,
2525
SQLiteCloudConnect,
26-
SQLiteCloudDataTypes,
2726
SQLiteCloudException,
2827
)
2928
from sqlitecloud.driver import Driver
@@ -145,7 +144,7 @@ def register_adapter(
145144

146145
class Connection:
147146
"""
148-
Represents a DB-APi 2.0 connection to the SQLite Cloud database.
147+
Represents a DB-API 2.0 connection to the SQLite Cloud database.
149148
150149
Args:
151150
SQLiteCloud_connection (SQLiteCloudConnect): The SQLite Cloud connection object.
@@ -155,13 +154,15 @@ class Connection:
155154
SQLiteCloud_connection (SQLiteCloudConnect): The SQLite Cloud connection object.
156155
"""
157156

158-
row_factory: Optional[Callable[["Cursor", Tuple], object]] = None
159-
text_factory: Union[Type[Union[str, bytes]], Callable[[bytes], object]] = str
160-
161157
def __init__(self, sqlitecloud_connection: SQLiteCloudConnect) -> None:
162158
self._driver = Driver()
163-
self.row_factory = None
164159
self.sqlitecloud_connection = sqlitecloud_connection
160+
161+
self.row_factory: Optional[Callable[["Cursor", Tuple], object]] = None
162+
self.text_factory: Union[
163+
Type[Union[str, bytes]], Callable[[bytes], object]
164+
] = str
165+
165166
self.detect_types = 0
166167

167168
@property
@@ -177,17 +178,15 @@ def sqlcloud_connection(self) -> SQLiteCloudConnect:
177178
def execute(
178179
self,
179180
sql: str,
180-
parameters: Union[
181-
Tuple[SQLiteCloudDataTypes], Dict[Union[str, int], SQLiteCloudDataTypes]
182-
] = (),
181+
parameters: Union[Tuple[any], Dict[Union[str, int], any]] = (),
183182
) -> "Cursor":
184183
"""
185184
Shortcut for cursor.execute().
186185
See the docstring of Cursor.execute() for more information.
187186
188187
Args:
189188
sql (str): The SQL query to execute.
190-
parameters (Union[Tuple[SQLiteCloudDataTypes], Dict[Union[str, int], SQLiteCloudDataTypes]]):
189+
parameters (Union[Tuple[any], Dict[Union[str, int], any]]):
191190
The parameters to be used in the query. It can be a tuple or a dictionary. (Default ())
192191
conn (SQLiteCloudConnect): The connection object to use for executing the query.
193192
@@ -200,19 +199,15 @@ def execute(
200199
def executemany(
201200
self,
202201
sql: str,
203-
seq_of_parameters: Iterable[
204-
Union[
205-
Tuple[SQLiteCloudDataTypes], Dict[Union[str, int], SQLiteCloudDataTypes]
206-
]
207-
],
202+
seq_of_parameters: Iterable[Union[Tuple[any], Dict[Union[str, int], any]]],
208203
) -> "Cursor":
209204
"""
210205
Shortcut for cursor.executemany().
211206
See the docstring of Cursor.executemany() for more information.
212207
213208
Args:
214209
sql (str): The SQL statement to execute.
215-
seq_of_parameters (Iterable[Union[Tuple[SQLiteCloudDataTypes], Dict[Union[str, int], SQLiteCloudDataTypes]]]):
210+
seq_of_parameters (Iterable[Union[Tuple[any], Dict[Union[str, int], any]]]):
216211
The sequence of parameter sets to bind to the SQL statement.
217212
218213
Returns:
@@ -385,9 +380,7 @@ def close(self) -> None:
385380
def execute(
386381
self,
387382
sql: str,
388-
parameters: Union[
389-
Tuple[SQLiteCloudDataTypes], Dict[Union[str, int], SQLiteCloudDataTypes]
390-
] = (),
383+
parameters: Union[Tuple[any], Dict[Union[str, int], any]] = (),
391384
) -> "Cursor":
392385
"""
393386
Prepare and execute a SQL statement (either a query or command) to the SQLite Cloud database.
@@ -405,7 +398,7 @@ def execute(
405398
406399
Args:
407400
sql (str): The SQL query to execute.
408-
parameters (Union[Tuple[SQLiteCloudDataTypes], Dict[Union[str, int], SQLiteCloudDataTypes]]):
401+
parameters (Union[Tuple[any], Dict[Union[str, int], any]]):
409402
The parameters to be used in the query. It can be a tuple or a dictionary. (Default ())
410403
conn (SQLiteCloudConnect): The connection object to use for executing the query.
411404
@@ -428,11 +421,7 @@ def execute(
428421
def executemany(
429422
self,
430423
sql: str,
431-
seq_of_parameters: Iterable[
432-
Union[
433-
Tuple[SQLiteCloudDataTypes], Dict[Union[str, int], SQLiteCloudDataTypes]
434-
]
435-
],
424+
seq_of_parameters: Iterable[Union[Tuple[any], Dict[Union[str, int], any]]],
436425
) -> "Cursor":
437426
"""
438427
Executes a SQL statement multiple times, each with a different set of parameters.
@@ -441,7 +430,7 @@ def executemany(
441430
442431
Args:
443432
sql (str): The SQL statement to execute.
444-
seq_of_parameters (Iterable[Union[Tuple[SQLiteCloudDataTypes], Dict[Union[str, int], SQLiteCloudDataTypes]]]):
433+
seq_of_parameters (Iterable[Union[Tuple[any], Dict[Union[str, int], any]]]):
445434
The sequence of parameter sets to bind to the SQL statement.
446435
447436
Returns:
@@ -564,10 +553,11 @@ def _get_value(self, row: int, col: int) -> Optional[Any]:
564553

565554
if self._connection.text_factory is bytes:
566555
return value.encode("utf-8")
567-
if self._connection.text_factory is str:
568-
return value
569-
# callable
570-
return self._connection.text_factory(value.encode("utf-8"))
556+
if self._connection.text_factory is not str and callable(
557+
self._connection.text_factory
558+
):
559+
return self._connection.text_factory(value.encode("utf-8"))
560+
return value
571561

572562
return self._resultset.get_value(row, col)
573563

src/sqlitecloud/driver.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -862,6 +862,7 @@ def _internal_parse_value(self, buffer: bytes, index: int = 0) -> SQLiteCloudVal
862862
if cellsize is not None:
863863
cellsize = 2
864864

865+
sqlitecloud_value.value = None
865866
sqlitecloud_value.len = len
866867
sqlitecloud_value.cellsize = cellsize
867868

@@ -877,7 +878,11 @@ def _internal_parse_value(self, buffer: bytes, index: int = 0) -> SQLiteCloudVal
877878
len = nlen - 2
878879
cellsize = nlen
879880

880-
sqlitecloud_value.value = (buffer[index + 1 : index + 1 + len]).decode()
881+
value = (buffer[index + 1 : index + 1 + len]).decode()
882+
883+
sqlitecloud_value.value = (
884+
int(value) if c == SQLITECLOUD_CMD.INT.value else float(value)
885+
)
881886
sqlitecloud_value.len
882887
sqlitecloud_value.cellsize = cellsize
883888

@@ -886,7 +891,12 @@ def _internal_parse_value(self, buffer: bytes, index: int = 0) -> SQLiteCloudVal
886891
len = blen - 1 if c == SQLITECLOUD_CMD.ZEROSTRING.value else blen
887892
cellsize = blen + cstart - index
888893

889-
sqlitecloud_value.value = (buffer[cstart : cstart + len]).decode()
894+
value = buffer[cstart : cstart + len]
895+
896+
if c == SQLITECLOUD_CMD.STRING.value or c == SQLITECLOUD_CMD.ZEROSTRING.value:
897+
value = value.decode()
898+
899+
sqlitecloud_value.value = value
890900
sqlitecloud_value.len = len
891901
sqlitecloud_value.cellsize = cellsize
892902

src/tests/integration/test_client.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -295,8 +295,8 @@ def test_array(self, sqlitecloud_connection):
295295
assert isinstance(result_array, list)
296296
assert len(result_array) == 5
297297
assert result_array[0] == "Hello World"
298-
assert result_array[1] == "123456"
299-
assert result_array[2] == "3.1415"
298+
assert result_array[1] == 123456
299+
assert result_array[2] == 3.1415
300300
assert result_array[3] is None
301301

302302
def test_rowset(self, sqlitecloud_connection):
@@ -310,6 +310,19 @@ def test_rowset(self, sqlitecloud_connection):
310310
assert result.get_name(0) == "key"
311311
assert result.get_name(1) == "value"
312312

313+
def test_rowset_data_types(self, sqlitecloud_connection):
314+
connection, client = sqlitecloud_connection
315+
316+
bindings = ("hello world", 15175, 3.14, b"bytes world", None)
317+
result = client.exec_statement("SELECT ?, ?, ?, ?, ?", bindings, connection)
318+
319+
assert SQLITECLOUD_RESULT_TYPE.RESULT_ROWSET == result.tag
320+
assert result.get_value(0, 0) == "hello world"
321+
assert result.get_value(0, 1) == 15175
322+
assert result.get_value(0, 2) == 3.14
323+
assert result.get_value(0, 3) == b"bytes world"
324+
assert result.get_value(0, 4) is None
325+
313326
def test_max_rows_option(self):
314327
account = SQLiteCloudAccount()
315328
account.hostname = os.getenv("SQLITE_HOST")
@@ -419,7 +432,7 @@ def test_serialized_operations(self, sqlitecloud_connection):
419432
assert 2 == rowset.ncols
420433
assert "count" == rowset.get_name(0)
421434
assert "string" == rowset.get_name(1)
422-
assert str(i) == rowset.get_value(0, 0)
435+
assert i == rowset.get_value(0, 0)
423436
assert rowset.version in [1, 2]
424437

425438
def test_query_timeout(self):
@@ -504,7 +517,7 @@ def test_select_results_with_no_column_name(self, sqlitecloud_connection):
504517
assert rowset.ncols == 2
505518
assert rowset.get_name(0) == "42"
506519
assert rowset.get_name(1) == "'hello'"
507-
assert rowset.get_value(0, 0) == "42"
520+
assert rowset.get_value(0, 0) == 42
508521
assert rowset.get_value(0, 1) == "hello"
509522

510523
def test_select_long_formatted_string(self, sqlitecloud_connection):

src/tests/integration/test_sqlite3_parity.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ def test_autocommit_mode_enabled_by_default(
266266

267267
connections = [
268268
(sqlitecloud_dbapi2_connection, next(get_sqlitecloud_dbapi2_connection())),
269-
(sqlite3_connection, next(self.get_sqlite3_connection())),
269+
(sqlite3_connection, next(get_sqlite3_connection())),
270270
]
271271

272272
for (connection, control_connection) in connections:
@@ -289,7 +289,7 @@ def test_explicit_transaction_to_commit(
289289

290290
connections = [
291291
(sqlitecloud_dbapi2_connection, next(get_sqlitecloud_dbapi2_connection())),
292-
(sqlite3_connection, next(self.get_sqlite3_connection())),
292+
(sqlite3_connection, next(get_sqlite3_connection())),
293293
]
294294

295295
for (connection, control_connection) in connections:
@@ -333,8 +333,7 @@ def test_text_factory_with_default_string(
333333
self, sqlitecloud_dbapi2_connection, sqlite3_connection
334334
):
335335
for connection in [sqlitecloud_dbapi2_connection, sqlite3_connection]:
336-
# by default is string
337-
# connection.text_factory = str
336+
# by default is string: connection.text_factory = str
338337

339338
austria = "\xd6sterreich"
340339

src/tests/unit/test_dbapi2.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,14 @@ def test_rowcount_with_no_resultset(self, mocker):
105105
assert cursor.rowcount == -1
106106

107107
def test_execute_escaped(self, mocker: MockerFixture):
108-
cursor = Cursor(mocker.patch("sqlitecloud.Connection"))
108+
connection = mocker.patch("sqlitecloud.Connection")
109+
apply_adapter_mock = mocker.patch.object(connection, "_apply_adapter")
110+
apply_adapter_mock.return_value = "John's"
111+
109112
execute_mock = mocker.patch.object(Driver, "execute")
110113

114+
cursor = Cursor(connection)
115+
111116
sql = "SELECT * FROM users WHERE name = ?"
112117
parameters = ("John's",)
113118

@@ -157,7 +162,10 @@ def test_fetchone_with_result(self, mocker):
157162
assert cursor.fetchone() is None
158163

159164
def test_fetchone_with_rowset(self, mocker):
160-
cursor = Cursor(mocker.patch("sqlitecloud.Connection"))
165+
connection = mocker.patch("sqlitecloud.Connection")
166+
connection.text_factory = str
167+
168+
cursor = Cursor(connection)
161169

162170
result = SQLiteCloudResult(SQLITECLOUD_RESULT_TYPE.RESULT_ROWSET)
163171
result.ncols = 1
@@ -195,7 +203,10 @@ def test_fetchmany_with_result(self, mocker):
195203
assert cursor.fetchmany() == []
196204

197205
def test_fetchmany_with_rowset_and_default_size(self, mocker):
198-
cursor = Cursor(mocker.patch("sqlitecloud.Connection"))
206+
connection = mocker.patch("sqlitecloud.Connection")
207+
connection.text_factory = str
208+
209+
cursor = Cursor(connection)
199210

200211
result = SQLiteCloudResult(SQLITECLOUD_RESULT_TYPE.RESULT_ROWSET)
201212
result.ncols = 1
@@ -207,7 +218,10 @@ def test_fetchmany_with_rowset_and_default_size(self, mocker):
207218
assert cursor.fetchmany(None) == [("myname1",)]
208219

209220
def test_fetchmany_twice_to_retrieve_whole_rowset(self, mocker):
210-
cursor = Cursor(mocker.patch("sqlitecloud.Connection"))
221+
connection = mocker.patch("sqlitecloud.Connection")
222+
connection.text_factory = str
223+
224+
cursor = Cursor(connection)
211225

212226
result = SQLiteCloudResult(SQLITECLOUD_RESULT_TYPE.RESULT_ROWSET)
213227
result.ncols = 1
@@ -220,7 +234,10 @@ def test_fetchmany_twice_to_retrieve_whole_rowset(self, mocker):
220234
assert cursor.fetchmany() == []
221235

222236
def test_fetchmany_with_size_higher_than_rowcount(self, mocker):
223-
cursor = Cursor(mocker.patch("sqlitecloud.Connection"))
237+
connection = mocker.patch("sqlitecloud.Connection")
238+
connection.text_factory = str
239+
240+
cursor = Cursor(connection)
224241

225242
result = SQLiteCloudResult(SQLITECLOUD_RESULT_TYPE.RESULT_ROWSET)
226243
result.ncols = 1
@@ -245,7 +262,10 @@ def test_fetchall_with_result(self, mocker):
245262
assert cursor.fetchall() == []
246263

247264
def test_fetchall_with_rowset(self, mocker):
248-
cursor = Cursor(mocker.patch("sqlitecloud.Connection"))
265+
connection = mocker.patch("sqlitecloud.Connection")
266+
connection.text_factory = str
267+
268+
cursor = Cursor(connection)
249269

250270
result = SQLiteCloudResult(SQLITECLOUD_RESULT_TYPE.RESULT_ROWSET)
251271
result.ncols = 1
@@ -257,7 +277,10 @@ def test_fetchall_with_rowset(self, mocker):
257277
assert cursor.fetchall() == [("myname1",), ("myname2",), ("myname3",)]
258278

259279
def test_fetchall_twice_and_expect_empty_list(self, mocker):
260-
cursor = Cursor(mocker.patch("sqlitecloud.Connection"))
280+
connection = mocker.patch("sqlitecloud.Connection")
281+
connection.text_factory = str
282+
283+
cursor = Cursor(connection)
261284

262285
result = SQLiteCloudResult(SQLITECLOUD_RESULT_TYPE.RESULT_ROWSET)
263286
result.ncols = 1
@@ -270,7 +293,10 @@ def test_fetchall_twice_and_expect_empty_list(self, mocker):
270293
assert cursor.fetchall() == []
271294

272295
def test_fetchall_to_return_remaining_rows(self, mocker):
273-
cursor = Cursor(mocker.patch("sqlitecloud.Connection"))
296+
connection = mocker.patch("sqlitecloud.Connection")
297+
connection.text_factory = str
298+
299+
cursor = Cursor(connection)
274300

275301
result = SQLiteCloudResult(SQLITECLOUD_RESULT_TYPE.RESULT_ROWSET)
276302
result.ncols = 1
@@ -283,7 +309,10 @@ def test_fetchall_to_return_remaining_rows(self, mocker):
283309
assert cursor.fetchall() == [("myname2",)]
284310

285311
def test_iterator(self, mocker):
286-
cursor = Cursor(mocker.patch("sqlitecloud.Connection"))
312+
connection = mocker.patch("sqlitecloud.Connection")
313+
connection.text_factory = str
314+
315+
cursor = Cursor(connection)
287316

288317
result = SQLiteCloudResult(SQLITECLOUD_RESULT_TYPE.RESULT_ROWSET)
289318
result.ncols = 1

0 commit comments

Comments
 (0)