Skip to content

Commit e21b06d

Browse files
committed
Add support for executemany in DBAPI
1 parent 0430df3 commit e21b06d

File tree

3 files changed

+81
-3
lines changed

3 files changed

+81
-3
lines changed

tests/integration/test_dbapi_integration.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import trino
2020
from tests.integration.conftest import trino_version
21-
from trino.exceptions import TrinoQueryError, TrinoUserError
21+
from trino.exceptions import TrinoQueryError, TrinoUserError, NotSupportedError
2222
from trino.transaction import IsolationLevel
2323

2424

@@ -124,6 +124,47 @@ def test_string_query_param(trino_connection):
124124
assert rows[0][0] == "six'"
125125

126126

127+
def test_execute_many(trino_connection):
128+
cur = trino_connection.cursor()
129+
cur.execute("CREATE TABLE memory.default.test_execute_many (key int, value varchar)")
130+
cur.fetchall()
131+
operation = "INSERT INTO memory.default.test_execute_many (key, value) VALUES (?, ?)"
132+
cur.executemany(operation, [(1, "value1")])
133+
cur.fetchall()
134+
cur.execute("SELECT * FROM memory.default.test_execute_many ORDER BY key")
135+
rows = cur.fetchall()
136+
assert len(list(rows)) == 1
137+
assert rows[0] == [1, "value1"]
138+
139+
operation = "INSERT INTO memory.default.test_execute_many (key, value) VALUES (?, ?)"
140+
cur.executemany(operation, [(2, "value2"), (3, "value3")])
141+
cur.fetchall()
142+
143+
cur.execute("SELECT * FROM memory.default.test_execute_many ORDER BY key")
144+
rows = cur.fetchall()
145+
assert len(list(rows)) == 3
146+
assert rows[0] == [1, "value1"]
147+
assert rows[1] == [2, "value2"]
148+
assert rows[2] == [3, "value3"]
149+
150+
151+
def test_execute_many_without_params(trino_connection):
152+
cur = trino_connection.cursor()
153+
cur.execute("CREATE TABLE memory.default.test_execute_many_without_param (value varchar)")
154+
cur.fetchall()
155+
cur.executemany("INSERT INTO memory.default.test_execute_many_without_param (value) VALUES (?)", [])
156+
with pytest.raises(TrinoUserError) as e:
157+
cur.fetchall()
158+
assert "Incorrect number of parameters: expected 1 but found 0" in str(e.value)
159+
160+
161+
def test_execute_many_select(trino_connection):
162+
cur = trino_connection.cursor()
163+
with pytest.raises(NotSupportedError) as e:
164+
cur.executemany("SELECT ?, ?", [(1, "value1"), (2, "value2")])
165+
assert "Query must return update type" in str(e.value)
166+
167+
127168
def test_python_types_not_used_when_experimental_python_types_is_not_set(trino_connection):
128169
cur = trino_connection.cursor()
129170

trino/client.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,13 @@ def get_session_property_values(headers, header):
107107

108108

109109
class TrinoStatus(object):
110-
def __init__(self, id, stats, warnings, info_uri, next_uri, rows, columns=None):
110+
def __init__(self, id, stats, warnings, info_uri, next_uri, update_type, rows, columns=None):
111111
self.id = id
112112
self.stats = stats
113113
self.warnings = warnings
114114
self.info_uri = info_uri
115115
self.next_uri = next_uri
116+
self.update_type = update_type
116117
self.rows = rows
117118
self.columns = columns
118119

@@ -448,6 +449,7 @@ def process(self, http_response) -> TrinoStatus:
448449
warnings=response.get("warnings", []),
449450
info_uri=response["infoUri"],
450451
next_uri=self._next_uri,
452+
update_type=response.get("updateType"),
451453
rows=response.get("data", []),
452454
columns=response.get("columns"),
453455
)
@@ -572,6 +574,7 @@ def __init__(
572574
self._finished = False
573575
self._cancelled = False
574576
self._request = request
577+
self._update_type = None
575578
self._sql = sql
576579
self._result = TrinoResult(self, experimental_python_types=experimental_python_types)
577580
self._response_headers = None
@@ -590,6 +593,10 @@ def columns(self):
590593
def stats(self):
591594
return self._stats
592595

596+
@property
597+
def update_type(self):
598+
return self._update_type
599+
593600
@property
594601
def warnings(self):
595602
return self._warnings
@@ -627,6 +634,7 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
627634

628635
def _update_state(self, status):
629636
self._stats.update(status.stats)
637+
self._update_type = status.update_type
630638
if status.columns:
631639
self._columns = status.columns
632640

trino/dbapi.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,12 @@ def info_uri(self):
243243
return self._query.info_uri
244244
return None
245245

246+
@property
247+
def update_type(self):
248+
if self._query is not None:
249+
return self._query.update_type
250+
return None
251+
246252
@property
247253
def description(self):
248254
if self._query.columns is None:
@@ -465,7 +471,30 @@ def execute(self, operation, params=None):
465471
return result
466472

467473
def executemany(self, operation, seq_of_params):
468-
raise trino.exceptions.NotSupportedError
474+
"""
475+
PEP-0249: Prepare a database operation (query or command) and then
476+
execute it against all parameter sequences or mappings found in the sequence seq_of_parameters.
477+
Modules are free to implement this method using multiple calls to
478+
the .execute() method or by using array operations to have the
479+
database process the sequence as a whole in one call.
480+
481+
Use of this method for an operation which produces one or more result
482+
sets constitutes undefined behavior, and the implementation is permitted (but not required)
483+
to raise an exception when it detects that a result set has been created by an invocation of the operation.
484+
485+
The same comments as for .execute() also apply accordingly to this method.
486+
487+
Return values are not defined.
488+
"""
489+
for parameters in seq_of_params[:-1]:
490+
self.execute(operation, parameters)
491+
self.fetchall()
492+
if self._query.update_type is None:
493+
raise NotSupportedError("Query must return update type")
494+
if seq_of_params:
495+
self.execute(operation, seq_of_params[-1])
496+
else:
497+
self.execute(operation)
469498

470499
def fetchone(self) -> Optional[List[Any]]:
471500
"""

0 commit comments

Comments
 (0)