Skip to content

Commit 57c9ead

Browse files
authored
Restore connection context if exception happens during rollback or commit (#1796)
* Add test cases exposing issue * Restore connection context if commit or rollback throws an exception
1 parent e762837 commit 57c9ead

File tree

3 files changed

+55
-23
lines changed

3 files changed

+55
-23
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ Changelog
1414
Fixed
1515
^^^^^
1616
- Fix bug related to `Connector.div` in combined expressions. (#1794)
17+
- Fix recovery in case of database downtime (#1796)
1718

1819
Changed
1920
^^^^^^^

tests/test_transactions.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
from unittest.mock import Mock
2+
13
from tests.testmodels import CharPkModel, Event, Team, Tournament
4+
from tortoise import connections
25
from tortoise.contrib import test
36
from tortoise.exceptions import OperationalError, TransactionManagementError
47
from tortoise.transactions import atomic, in_transaction
@@ -213,3 +216,27 @@ async def test_select_await_across_transaction_success(self):
213216
self.assertEqual(
214217
await Tournament.all().values("id", "name"), [{"id": obj.id, "name": "Test1"}]
215218
)
219+
220+
221+
@test.requireCapability(supports_transactions=True)
222+
class TestIsolatedTransactions(test.IsolatedTestCase):
223+
"""Running these in isolation because they mess with the global state of the connections."""
224+
225+
async def test_rollback_raising_exception(self):
226+
"""Tests that if a rollback raises an exception, the connection context is restored."""
227+
conn = connections.get("models")
228+
with self.assertRaisesRegex(ValueError, "rollback"):
229+
async with conn._in_transaction() as tx_conn:
230+
tx_conn.rollback = Mock(side_effect=ValueError("rollback"))
231+
raise ValueError("initial exception")
232+
233+
self.assertEqual(connections.get("models"), conn)
234+
235+
async def test_commit_raising_exception(self):
236+
"""Tests that if a commit raises an exception, the connection context is restored."""
237+
conn = connections.get("models")
238+
with self.assertRaisesRegex(ValueError, "commit"):
239+
async with conn._in_transaction() as tx_conn:
240+
tx_conn.commit = Mock(side_effect=ValueError("commit"))
241+
242+
self.assertEqual(connections.get("models"), conn)

tortoise/backends/base/client.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ class TransactionContext(Generic[T_conn]):
246246
def __init__(self, connection: Any) -> None:
247247
self.connection = connection
248248
self.connection_name = connection.connection_name
249-
self.lock = getattr(connection, "_trxlock", None)
249+
self.lock = connection._trxlock
250250

251251
async def ensure_connection(self) -> None:
252252
if not self.connection._connection:
@@ -255,21 +255,23 @@ async def ensure_connection(self) -> None:
255255

256256
async def __aenter__(self) -> T_conn:
257257
await self.ensure_connection()
258-
await self.lock.acquire() # type:ignore
258+
await self.lock.acquire()
259259
self.token = connections.set(self.connection_name, self.connection)
260260
await self.connection.start()
261261
return self.connection
262262

263263
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
264-
if not self.connection._finalized:
265-
if exc_type:
266-
# Can't rollback a transaction that already failed.
267-
if exc_type is not TransactionManagementError:
268-
await self.connection.rollback()
269-
else:
270-
await self.connection.commit()
271-
connections.reset(self.token)
272-
self.lock.release() # type:ignore
264+
try:
265+
if not self.connection._finalized:
266+
if exc_type:
267+
# Can't rollback a transaction that already failed.
268+
if exc_type is not TransactionManagementError:
269+
await self.connection.rollback()
270+
else:
271+
await self.connection.commit()
272+
finally:
273+
connections.reset(self.token)
274+
self.lock.release()
273275

274276

275277
class TransactionContextPooled(TransactionContext):
@@ -287,16 +289,18 @@ async def __aenter__(self) -> T_conn:
287289
return self.connection
288290

289291
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
290-
if not self.connection._finalized:
291-
if exc_type:
292-
# Can't rollback a transaction that already failed.
293-
if exc_type is not TransactionManagementError:
294-
await self.connection.rollback()
295-
else:
296-
await self.connection.commit()
297-
if self.connection._parent._pool:
298-
await self.connection._parent._pool.release(self.connection._connection)
299-
connections.reset(self.token)
292+
try:
293+
if not self.connection._finalized:
294+
if exc_type:
295+
# Can't rollback a transaction that already failed.
296+
if exc_type is not TransactionManagementError:
297+
await self.connection.rollback()
298+
else:
299+
await self.connection.commit()
300+
finally:
301+
if self.connection._parent._pool:
302+
await self.connection._parent._pool.release(self.connection._connection)
303+
connections.reset(self.token)
300304

301305

302306
class NestedTransactionContext(TransactionContext):
@@ -313,11 +317,11 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
313317

314318
class NestedTransactionPooledContext(TransactionContext):
315319
async def __aenter__(self) -> T_conn:
316-
await self.lock.acquire() # type:ignore
320+
await self.lock.acquire()
317321
return self.connection
318322

319323
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
320-
self.lock.release() # type:ignore
324+
self.lock.release()
321325
if not self.connection._finalized:
322326
if exc_type:
323327
# Can't rollback a transaction that already failed.

0 commit comments

Comments
 (0)