Skip to content

Commit 26793e9

Browse files
committed
MySQL connection should be closed correctly now, see pytest-dev#26 + tests
1 parent 6d1ad22 commit 26793e9

File tree

2 files changed

+63
-52
lines changed

2 files changed

+63
-52
lines changed

peewee_async.py

Lines changed: 49 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -564,7 +564,7 @@ def select(query):
564564
except GeneratorExit:
565565
pass
566566

567-
cursor.release()
567+
yield from cursor.release
568568
return result
569569

570570

@@ -584,7 +584,7 @@ def insert(query):
584584
result = yield from query.database.last_insert_id_async(
585585
cursor, query.model_class)
586586

587-
cursor.release()
587+
yield from cursor.release
588588
return result
589589

590590

@@ -599,7 +599,7 @@ def update(query):
599599
cursor = yield from _execute_query_async(query)
600600
rowcount = cursor.rowcount
601601

602-
cursor.release()
602+
yield from cursor.release
603603
return rowcount
604604

605605

@@ -614,7 +614,7 @@ def delete(query):
614614
cursor = yield from _execute_query_async(query)
615615
rowcount = cursor.rowcount
616616

617-
cursor.release()
617+
yield from cursor.release
618618
return rowcount
619619

620620

@@ -650,7 +650,7 @@ def scalar(query, as_tuple=False):
650650
cursor = yield from _execute_query_async(query)
651651
row = yield from cursor.fetchone()
652652

653-
cursor.release()
653+
yield from cursor.release
654654
if row and not as_tuple:
655655
return row[0]
656656
else:
@@ -672,7 +672,7 @@ def raw_query(query):
672672
except GeneratorExit:
673673
pass
674674

675-
cursor.release()
675+
yield from cursor.release
676676
return result
677677

678678

@@ -983,36 +983,35 @@ def connect(self):
983983
**self.connect_kwargs)
984984

985985
@asyncio.coroutine
986-
def cursor(self, conn=None, *args, **kwargs):
987-
"""Get cursor for connection from pool.
986+
def close(self):
987+
"""Terminate all pool connections.
988988
"""
989-
if conn is None:
990-
# Acquire connection with cursor, once cursor is released
991-
# connection is also released to pool:
989+
self.pool.terminate()
990+
yield from self.pool.wait_closed()
992991

992+
@asyncio.coroutine
993+
def cursor(self, conn=None, *args, **kwargs):
994+
"""Get a cursor for the specified transaction connection
995+
or acquire from the pool.
996+
"""
997+
in_transaction = conn is not None
998+
if not conn:
993999
conn = yield from self.acquire()
994-
cursor = yield from conn.cursor(*args, **kwargs)
995-
996-
def release():
997-
cursor.close()
998-
self.pool.release(conn)
999-
cursor.release = release
1000-
else:
1001-
# Acquire cursor from provided connection, after cursor is
1002-
# released connection is NOT released to pool, i.e.
1003-
# for handling transactions:
1004-
1005-
cursor = yield from conn.cursor(*args, **kwargs)
1006-
cursor.release = lambda: cursor.close()
1007-
1000+
cursor = yield from conn.cursor(*args, **kwargs)
1001+
# NOTE: `cursor.release` is an awaitable object!
1002+
cursor.release = self.release_cursor(
1003+
cursor, in_transaction=in_transaction)
10081004
return cursor
10091005

10101006
@asyncio.coroutine
1011-
def close(self):
1012-
"""Terminate all pool connections.
1007+
def release_cursor(self, cursor, in_transaction=False):
1008+
"""Release cursor coroutine. Unless in transaction,
1009+
the connection is also released back to the pool.
10131010
"""
1014-
self.pool.terminate()
1015-
yield from self.pool.wait_closed()
1011+
conn = cursor.connection
1012+
cursor.close()
1013+
if not in_transaction:
1014+
self.pool.release(conn)
10161015

10171016

10181017
class AsyncPostgresqlMixin(AsyncDatabase):
@@ -1143,37 +1142,35 @@ def connect(self):
11431142
connect_timeout=self.timeout,
11441143
**self.connect_kwargs)
11451144

1145+
@asyncio.coroutine
1146+
def close(self):
1147+
"""Terminate all pool connections.
1148+
"""
1149+
self.pool.terminate()
1150+
yield from self.pool.wait_closed()
1151+
11461152
@asyncio.coroutine
11471153
def cursor(self, conn=None, *args, **kwargs):
11481154
"""Get cursor for connection from pool.
11491155
"""
1150-
if conn is None:
1151-
# Acquire connection with cursor, once cursor is released
1152-
# connection is also released to pool:
1153-
1156+
in_transaction = conn is not None
1157+
if not conn:
11541158
conn = yield from self.acquire()
1155-
cursor = yield from conn.cursor(*args, **kwargs)
1156-
1157-
def release():
1158-
cursor.close()
1159-
self.pool.release(conn)
1160-
cursor.release = release
1161-
else:
1162-
# Acquire cursor from provided connection, after cursor is
1163-
# released connection is NOT released to pool, i.e.
1164-
# for handling transactions:
1165-
1166-
cursor = yield from conn.cursor(*args, **kwargs)
1167-
cursor.release = lambda: cursor.close()
1168-
1159+
cursor = yield from conn.cursor(*args, **kwargs)
1160+
# NOTE: `cursor.release` is an awaitable object!
1161+
cursor.release = self.release_cursor(
1162+
cursor, in_transaction=in_transaction)
11691163
return cursor
11701164

11711165
@asyncio.coroutine
1172-
def close(self):
1173-
"""Terminate all pool connections.
1166+
def release_cursor(self, cursor, in_transaction=False):
1167+
"""Release cursor coroutine. Unless in transaction,
1168+
the connection is also released back to the pool.
11741169
"""
1175-
self.pool.terminate()
1176-
yield from self.pool.wait_closed()
1170+
conn = cursor.connection
1171+
yield from cursor.close()
1172+
if not in_transaction:
1173+
self.pool.release(conn)
11771174

11781175

11791176
class MySQLDatabase(AsyncDatabase, peewee.MySQLDatabase):
@@ -1395,7 +1392,7 @@ def _run_sql(database, operation, *args, **kwargs):
13951392
try:
13961393
yield from cursor.execute(operation, *args, **kwargs)
13971394
except:
1398-
cursor.release()
1395+
yield from cursor.release
13991396
raise
14001397

14011398
return cursor

tests/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,20 @@ def test(objects):
431431

432432
self.run_with_managers(test)
433433

434+
def test_many_requests(self):
435+
@asyncio.coroutine
436+
def test(objects):
437+
max_connections = getattr(objects.database, 'max_connections', 0)
438+
text = "Test %s" % uuid.uuid4()
439+
obj = yield from objects.create(TestModel, text=text)
440+
n = 2 * max_connections # number of requests
441+
done, not_done = yield from asyncio.wait(
442+
[objects.get(TestModel, id=obj.id) for _ in range(n)],
443+
loop=self.loop)
444+
self.assertEqual(len(done), n)
445+
446+
self.run_with_managers(test)
447+
434448
def test_create_obj(self):
435449
@asyncio.coroutine
436450
def test(objects):

0 commit comments

Comments
 (0)