Skip to content

Commit d8dda8b

Browse files
committed
TaskLocals optimization: don't create data for every task when getting values; raise RuntimeError if transaction runs without task context
1 parent ebe362e commit d8dda8b

File tree

1 file changed

+16
-17
lines changed

1 file changed

+16
-17
lines changed

peewee_async.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1294,18 +1294,12 @@ def rollback(self, begin=True):
12941294
yield from _run_sql(self.db, 'BEGIN')
12951295

12961296
@asyncio.coroutine
1297-
def begin(self):
1297+
def __aenter__(self):
1298+
if not asyncio.Task.current_task(loop=self.loop):
1299+
raise RuntimeError("The transaction must run within a task")
12981300
yield from self.db.push_transaction_async()
1299-
13001301
if self.db.transaction_depth_async() == 1:
13011302
yield from _run_sql(self.db, 'BEGIN')
1302-
1303-
@asyncio.coroutine
1304-
def __aenter__(self):
1305-
if asyncio.Task.current_task(loop=self.loop):
1306-
yield from self.begin()
1307-
else:
1308-
yield from self.loop.create_task(self.begin())
13091303
return self
13101304

13111305
@asyncio.coroutine
@@ -1440,24 +1434,29 @@ def get(self, key, *val):
14401434
def set(self, key, val):
14411435
"""Set value stored for current running task.
14421436
"""
1443-
data = self.get_data()
1437+
data = self.get_data(True)
14441438
if data is not None:
14451439
data[key] = val
14461440
else:
14471441
raise RuntimeError("No task is currently running")
14481442

1449-
def get_data(self):
1450-
"""Get dict stored for current running task.
1443+
def get_data(self, create=False):
1444+
"""Get dict stored for current running task. Return `None`
1445+
or an empty dict if no data was found depending on the
1446+
`create` argument value.
1447+
1448+
:param create: if argument is `True`, create empty dict
1449+
for task, default: `False`
14511450
"""
14521451
task = asyncio.Task.current_task(loop=self.loop)
14531452
if task:
14541453
task_id = id(task)
1455-
if not task_id in self.data:
1454+
if create and not task_id in self.data:
14561455
self.data[task_id] = {}
1457-
task.add_done_callback(self.pop_data)
1458-
return self.data[task_id]
1456+
task.add_done_callback(self.del_data)
1457+
return self.data.get(task_id)
14591458

1460-
def pop_data(self, task):
1459+
def del_data(self, task):
14611460
"""Delete data for task from stored data dict.
14621461
"""
1463-
self.data.pop(id(task), None)
1462+
del self.data[id(task)]

0 commit comments

Comments
 (0)