Skip to content

Commit 70269ab

Browse files
committed
Allow the Manager class to be initialized with peewee.Proxy, see pytest-dev#28
1 parent 700b971 commit 70269ab

File tree

2 files changed

+45
-2
lines changed

2 files changed

+45
-2
lines changed

peewee_async.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,11 @@ def __init__(self, database=None, *, loop=None):
109109

110110
self.loop = loop or asyncio.get_event_loop()
111111
self.database = database or self.database
112-
self.database.loop = self.loop
112+
attach_callback = getattr(self.database, 'attach_callback', None)
113+
if attach_callback:
114+
attach_callback(lambda db: db.set_event_loop(self.loop))
115+
else:
116+
self.database.set_event_loop(self.loop)
113117

114118
@property
115119
def is_connected(self):
@@ -819,6 +823,22 @@ class AsyncDatabase:
819823
_async_wait = None # connection waiter
820824
_task_data = None # task context data
821825

826+
def set_event_loop(self, loop):
827+
"""Set event loop for the database. Usually, you don't need to
828+
call this directly. It's called from `Manager.connect()` or
829+
`.connect_async()` methods.
830+
"""
831+
# These checks are not very pythonic, but I believe it's OK to be
832+
# a little paranoid about mismatching of asyncio event loops,
833+
# because such errors won't show clear traceback and could be
834+
# tricky to debug.
835+
loop = loop or asyncio.get_event_loop()
836+
if not self.loop:
837+
self.loop = loop
838+
elif self.loop != loop:
839+
raise RuntimeError("Error, the event loop is already set before. "
840+
"Make sure you're using the same event loop!")
841+
822842
@asyncio.coroutine
823843
def connect_async(self, loop=None, timeout=None):
824844
"""Set up async connection on specified event loop or
@@ -833,7 +853,7 @@ def connect_async(self, loop=None, timeout=None):
833853
elif self._async_wait:
834854
yield from self._async_wait
835855
else:
836-
self.loop = loop or asyncio.get_event_loop()
856+
self.set_event_loop(loop)
837857
self._async_wait = asyncio.Future(loop=self.loop)
838858

839859
conn = self._async_conn_cls(

tests/__init__.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,29 @@ def test_deferred_init(self):
265265
TestModel.create_table(True)
266266
TestModel.drop_table(True)
267267

268+
def test_proxy_database(self):
269+
loop = asyncio.new_event_loop()
270+
database = peewee.Proxy()
271+
TestModel._meta.database = database
272+
objects = peewee_async.Manager(database, loop=loop)
273+
274+
@asyncio.coroutine
275+
def test(objects):
276+
text = "Test %s" % uuid.uuid4()
277+
yield from objects.create(TestModel, text=text)
278+
279+
config = dict(defaults)
280+
for k in list(config.keys()):
281+
config[k].update(overrides.get(k, {}))
282+
database.initialize(db_classes[k](**config[k]))
283+
284+
TestModel.create_table(True)
285+
loop.run_until_complete(test(objects))
286+
loop.run_until_complete(objects.close())
287+
TestModel.drop_table(True)
288+
289+
loop.close()
290+
268291

269292
class OlderTestCase(unittest.TestCase):
270293
# only = ['postgres', 'postgres-ext', 'postgres-pool', 'postgres-pool-ext']

0 commit comments

Comments
 (0)