Skip to content

Commit be5d2e8

Browse files
committed
Refactored bg scheduler
1 parent 2672fce commit be5d2e8

File tree

5 files changed

+98
-17
lines changed

5 files changed

+98
-17
lines changed

redis/asyncio/multidb/client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33
from typing import Any, Awaitable, Callable, Coroutine, List, Optional, Union
44

5-
from redis.asyncio.client import PSWorkerThreadExcHandlerT, PubSubHandler
5+
from redis.asyncio.client import PubSubHandler
66
from redis.asyncio.multidb.command_executor import DefaultCommandExecutor
77
from redis.asyncio.multidb.config import DEFAULT_GRACE_PERIOD, MultiDbConfig
88
from redis.asyncio.multidb.database import AsyncDatabase, Databases
@@ -507,7 +507,7 @@ async def get_message(
507507
async def run(
508508
self,
509509
*,
510-
exception_handler: Optional["PSWorkerThreadExcHandlerT"] = None,
510+
exception_handler=None,
511511
poll_timeout: float = 1.0,
512512
) -> None:
513513
"""Process pub/sub messages using registered callbacks.
@@ -524,5 +524,5 @@ async def run(
524524
>>> await task
525525
"""
526526
return await self._client.command_executor.execute_pubsub_run(
527-
exception_handler=exception_handler, sleep_time=poll_timeout, pubsub=self
527+
sleep_time=poll_timeout, exception_handler=exception_handler, pubsub=self
528528
)

redis/asyncio/multidb/command_executor.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,9 +266,15 @@ async def callback():
266266

267267
return await self._execute_with_failure_detection(callback, *args)
268268

269-
async def execute_pubsub_run(self, sleep_time: float, **kwargs) -> Any:
269+
async def execute_pubsub_run(
270+
self, sleep_time: float, exception_handler=None, pubsub=None
271+
) -> Any:
270272
async def callback():
271-
return await self._active_pubsub.run(poll_timeout=sleep_time, **kwargs)
273+
return await self._active_pubsub.run(
274+
poll_timeout=sleep_time,
275+
exception_handler=exception_handler,
276+
pubsub=pubsub,
277+
)
272278

273279
return await self._execute_with_failure_detection(callback)
274280

redis/background.py

Lines changed: 81 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,47 @@ class BackgroundScheduler:
1010

1111
def __init__(self):
1212
self._next_timer = None
13+
self._event_loops = []
14+
self._lock = threading.Lock()
15+
self._stopped = False
1316

1417
def __del__(self):
15-
if self._next_timer:
16-
self._next_timer.cancel()
18+
self.stop()
19+
20+
def stop(self):
21+
"""
22+
Stop all scheduled tasks and clean up resources.
23+
"""
24+
with self._lock:
25+
if self._stopped:
26+
return
27+
self._stopped = True
28+
29+
if self._next_timer:
30+
self._next_timer.cancel()
31+
self._next_timer = None
32+
33+
# Stop all event loops
34+
for loop in self._event_loops:
35+
if loop.is_running():
36+
loop.call_soon_threadsafe(loop.stop)
37+
38+
self._event_loops.clear()
1739

1840
def run_once(self, delay: float, callback: Callable, *args):
1941
"""
2042
Runs callable task once after certain delay in seconds.
2143
"""
44+
with self._lock:
45+
if self._stopped:
46+
return
47+
2248
# Run loop in a separate thread to unblock main thread.
2349
loop = asyncio.new_event_loop()
50+
51+
with self._lock:
52+
self._event_loops.append(loop)
53+
2454
thread = threading.Thread(
2555
target=_start_event_loop_in_thread,
2656
args=(loop, self._call_later, delay, callback, *args),
@@ -32,9 +62,16 @@ def run_recurring(self, interval: float, callback: Callable, *args):
3262
"""
3363
Runs recurring callable task with given interval in seconds.
3464
"""
65+
with self._lock:
66+
if self._stopped:
67+
return
68+
3569
# Run loop in a separate thread to unblock main thread.
3670
loop = asyncio.new_event_loop()
3771

72+
with self._lock:
73+
self._event_loops.append(loop)
74+
3875
thread = threading.Thread(
3976
target=_start_event_loop_in_thread,
4077
args=(loop, self._call_later_recurring, interval, callback, *args),
@@ -49,10 +86,17 @@ async def run_recurring_async(
4986
Runs recurring coroutine with given interval in seconds in the current event loop.
5087
To be used only from an async context. No additional threads are created.
5188
"""
89+
with self._lock:
90+
if self._stopped:
91+
return
92+
5293
loop = asyncio.get_running_loop()
5394
wrapped = _async_to_sync_wrapper(loop, coro, *args)
5495

5596
def tick():
97+
with self._lock:
98+
if self._stopped:
99+
return
56100
# Schedule the coroutine
57101
wrapped()
58102
# Schedule next tick
@@ -64,6 +108,9 @@ def tick():
64108
def _call_later(
65109
self, loop: asyncio.AbstractEventLoop, delay: float, callback: Callable, *args
66110
):
111+
with self._lock:
112+
if self._stopped:
113+
return
67114
self._next_timer = loop.call_later(delay, callback, *args)
68115

69116
def _call_later_recurring(
@@ -73,6 +120,9 @@ def _call_later_recurring(
73120
callback: Callable,
74121
*args,
75122
):
123+
with self._lock:
124+
if self._stopped:
125+
return
76126
self._call_later(
77127
loop, interval, self._execute_recurring, loop, interval, callback, *args
78128
)
@@ -87,7 +137,19 @@ def _execute_recurring(
87137
"""
88138
Executes recurring callable task with given interval in seconds.
89139
"""
90-
callback(*args)
140+
with self._lock:
141+
if self._stopped:
142+
return
143+
144+
try:
145+
callback(*args)
146+
except Exception:
147+
# Silently ignore exceptions during shutdown
148+
pass
149+
150+
with self._lock:
151+
if self._stopped:
152+
return
91153

92154
self._call_later(
93155
loop, interval, self._execute_recurring, loop, interval, callback, *args
@@ -106,7 +168,22 @@ def _start_event_loop_in_thread(
106168
"""
107169
asyncio.set_event_loop(event_loop)
108170
event_loop.call_soon(call_soon_cb, event_loop, *args)
109-
event_loop.run_forever()
171+
try:
172+
event_loop.run_forever()
173+
finally:
174+
try:
175+
# Clean up pending tasks
176+
pending = asyncio.all_tasks(event_loop)
177+
for task in pending:
178+
task.cancel()
179+
# Run loop once more to process cancellations
180+
event_loop.run_until_complete(
181+
asyncio.gather(*pending, return_exceptions=True)
182+
)
183+
except Exception:
184+
pass
185+
finally:
186+
event_loop.close()
110187

111188

112189
def _async_to_sync_wrapper(loop, coro_func, *args, **kwargs):

redis/event.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,24 +117,22 @@ async def dispatch_async(self, event: object):
117117

118118
def register_listeners(
119119
self,
120-
event_listeners: Dict[
120+
mappings: Dict[
121121
Type[object],
122122
List[Union[EventListenerInterface, AsyncEventListenerInterface]],
123123
],
124124
):
125125
with self._lock:
126-
for event_type in event_listeners:
126+
for event_type in mappings:
127127
if event_type in self._event_listeners_mapping:
128128
self._event_listeners_mapping[event_type] = list(
129129
set(
130130
self._event_listeners_mapping[event_type]
131-
+ event_listeners[event_type]
131+
+ mappings[event_type]
132132
)
133133
)
134134
else:
135-
self._event_listeners_mapping[event_type] = event_listeners[
136-
event_type
137-
]
135+
self._event_listeners_mapping[event_type] = mappings[event_type]
138136

139137

140138
class AfterConnectionReleasedEvent:

tests/test_event.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def callback(event):
3131
mock_another_event_listener = Mock(spec=EventListenerInterface)
3232
mock_another_event_listener.listen = callback
3333
dispatcher.register_listeners(
34-
event_listeners={type(mock_event): [mock_another_event_listener]}
34+
mappings={type(mock_event): [mock_another_event_listener]}
3535
)
3636
dispatcher.dispatch(mock_event)
3737

@@ -60,7 +60,7 @@ async def callback(event):
6060
mock_another_event_listener = Mock(spec=AsyncEventListenerInterface)
6161
mock_another_event_listener.listen = callback
6262
dispatcher.register_listeners(
63-
event_listeners={type(mock_event): [mock_another_event_listener]}
63+
mappings={type(mock_event): [mock_another_event_listener]}
6464
)
6565
await dispatcher.dispatch_async(mock_event)
6666

0 commit comments

Comments
 (0)