Skip to content

Commit 75a865d

Browse files
committed
refactor: side effect runners always run the side effect in the event loop provided to them regardless of the return value of the side effect being a coroutine or not, this is because even if the side effect is not a coroutine, it might still use async features internally
1 parent 86f07ae commit 75a865d

File tree

7 files changed

+93
-88
lines changed

7 files changed

+93
-88
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## Upcoming
44

55
- refactor: provide correct signature for the autorun instance based on the function it decorates
6+
- refactor: side effect runners always run the side effect in the event loop provided to them regardless of the return value of the side effect being a coroutine or not, this is because even if the side effect is not a coroutine, it might still use async features internally
67

78
## Version 0.18.3
89

redux/main.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,10 @@ def __init__(
9393
tuple[EventHandler[Event], Event] | None
9494
]()
9595
self._workers = [
96-
SideEffectRunnerThread(task_queue=self._event_handlers_queue)
96+
SideEffectRunnerThread(
97+
task_queue=self._event_handlers_queue,
98+
create_task=self._create_task,
99+
)
97100
for _ in range(self.store_options.threads)
98101
]
99102
for worker in self._workers:

redux/side_effect_runner.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44

55
import asyncio
66
import contextlib
7+
import inspect
78
import threading
89
import weakref
910
from asyncio import Handle, iscoroutine
1011
from collections.abc import Callable
11-
from inspect import signature
1212
from typing import TYPE_CHECKING, Any, Generic, cast
1313

14-
from redux.basic_types import Event, EventHandler
14+
from redux.basic_types import Event, EventHandler, TaskCreator
1515

1616
if TYPE_CHECKING:
1717
import queue
@@ -24,15 +24,14 @@ def __init__(
2424
self: SideEffectRunnerThread,
2525
*,
2626
task_queue: queue.Queue[tuple[EventHandler[Event], Event] | None],
27+
create_task: TaskCreator | None,
2728
) -> None:
2829
"""Initialize the side effect runner thread."""
2930
super().__init__()
3031
self.task_queue = task_queue
3132
self.loop = asyncio.get_event_loop()
3233
self._handles: set[Handle] = set()
33-
self.create_task = lambda coro: self._handles.add(
34-
self.loop.call_soon_threadsafe(self.loop.create_task, coro),
35-
)
34+
self.create_task = create_task
3635

3736
def run(self: SideEffectRunnerThread[Event]) -> None:
3837
"""Run the side effect runner thread."""
@@ -51,12 +50,27 @@ def run(self: SideEffectRunnerThread[Event]) -> None:
5150
event_handler = event_handler_
5251
parameters = 1
5352
with contextlib.suppress(Exception):
54-
parameters = len(signature(event_handler).parameters)
55-
if parameters == 1:
56-
result = cast(Callable[[Event], Any], event_handler)(event)
57-
else:
58-
result = cast(Callable[[], Any], event_handler)()
59-
if iscoroutine(result):
60-
self.create_task(result)
53+
parameters = len(inspect.signature(event_handler).parameters)
54+
55+
if self.create_task:
56+
57+
async def _(
58+
event_handler: EventHandler[Event],
59+
event: Event,
60+
parameters: int,
61+
) -> None:
62+
if parameters == 1:
63+
result = cast(Callable[[Event], Any], event_handler)(event)
64+
else:
65+
result = cast(Callable[[], Any], event_handler)()
66+
if iscoroutine(result):
67+
await result
68+
69+
self.create_task(_(event_handler, event, parameters))
70+
else: # noqa: PLR5501
71+
if parameters == 1:
72+
cast(Callable[[Event], Any], event_handler)(event)
73+
else:
74+
cast(Callable[[], Any], event_handler)()
6175
finally:
6276
self.task_queue.task_done()

redux_pytest/fixtures/event_loop.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,37 @@
88
import pytest
99

1010
if TYPE_CHECKING:
11-
from collections.abc import Coroutine
11+
from collections.abc import Callable, Coroutine
12+
13+
from redux.basic_types import TaskCreatorCallback
1214

1315

1416
class LoopThread(threading.Thread):
1517
def __init__(self: LoopThread) -> None:
1618
super().__init__()
1719
self.loop = asyncio.new_event_loop()
18-
asyncio.set_event_loop(self.loop)
1920

2021
def run(self: LoopThread) -> None:
2122
self.loop.run_forever()
2223

2324
def stop(self: LoopThread) -> None:
24-
asyncio.set_event_loop(None)
2525
self.loop.call_soon_threadsafe(self.loop.stop)
2626

27-
def create_task(self: LoopThread, coro: Coroutine) -> None:
28-
self.loop.call_soon_threadsafe(self.loop.create_task, coro)
27+
def create_task(
28+
self: LoopThread,
29+
coro: Coroutine,
30+
*,
31+
callback: TaskCreatorCallback | None = None,
32+
) -> None:
33+
def _(
34+
coro: Coroutine,
35+
callback: Callable[[asyncio.Task], None] | None = None,
36+
) -> None:
37+
task = self.loop.create_task(coro)
38+
if callback:
39+
task.add_done_callback(callback)
40+
41+
self.loop.call_soon_threadsafe(_, coro, callback)
2942

3043

3144
@pytest.fixture

tests/test_async.py

Lines changed: 33 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from redux.basic_types import (
1212
AutorunOptions,
1313
BaseAction,
14+
BaseEvent,
1415
CompleteReducerResult,
1516
CreateStoreOptions,
1617
FinishAction,
@@ -22,7 +23,7 @@
2223
from redux.main import Store
2324

2425
if TYPE_CHECKING:
25-
from collections.abc import Callable, Coroutine
26+
from collections.abc import Generator
2627

2728
from redux_pytest.fixtures.event_loop import LoopThread
2829

@@ -37,21 +38,28 @@ class StateType(Immutable):
3738
class IncrementAction(BaseAction): ...
3839

3940

41+
class IncrementEvent(BaseEvent):
42+
post_value: int
43+
44+
4045
class SetMirroredValueAction(BaseAction):
4146
value: int
4247

4348

4449
def reducer(
4550
state: StateType | None,
4651
action: Action,
47-
) -> StateType | CompleteReducerResult[StateType, Action, FinishEvent]:
52+
) -> StateType | CompleteReducerResult[StateType, Action, IncrementEvent]:
4853
if state is None:
4954
if isinstance(action, InitAction):
5055
return StateType(value=0, mirrored_value=0)
5156
raise InitializationActionError(action)
5257

5358
if isinstance(action, IncrementAction):
54-
return replace(state, value=state.value + 1)
59+
return CompleteReducerResult(
60+
state=replace(state, value=state.value + 1),
61+
events=[IncrementEvent(post_value=state.value + 1)],
62+
)
5563
if isinstance(action, SetMirroredValueAction):
5664
return replace(state, mirrored_value=action.value)
5765
return state
@@ -62,25 +70,16 @@ def reducer(
6270

6371

6472
@pytest.fixture
65-
def store(event_loop: LoopThread) -> StoreType:
66-
def _create_task_with_callback(
67-
coro: Coroutine,
68-
callback: Callable[[asyncio.Task], None] | None = None,
69-
) -> None:
70-
def create_task_with_callback() -> None:
71-
task = event_loop.loop.create_task(coro)
72-
if callback:
73-
callback(task)
74-
75-
event_loop.loop.call_soon_threadsafe(create_task_with_callback)
76-
77-
return Store(
73+
def store(event_loop: LoopThread) -> Generator[StoreType, None, None]:
74+
store = Store(
7875
reducer,
7976
options=CreateStoreOptions(
8077
auto_init=True,
81-
task_creator=_create_task_with_callback,
78+
task_creator=event_loop.create_task,
8279
),
8380
)
81+
yield store
82+
store.subscribe_event(FinishEvent, lambda: event_loop.stop())
8483

8584

8685
def dispatch_actions(store: StoreType) -> None:
@@ -90,7 +89,6 @@ def dispatch_actions(store: StoreType) -> None:
9089

9190
def test_autorun(
9291
store: StoreType,
93-
event_loop: LoopThread,
9492
) -> None:
9593
@store.autorun(lambda state: state.value)
9694
async def sync_mirror(value: int) -> int:
@@ -107,16 +105,12 @@ async def sync_mirror(value: int) -> int:
107105
def _(mirrored_value: int) -> None:
108106
if mirrored_value < INCREMENTS:
109107
return
110-
event_loop.stop()
111108
store.dispatch(FinishAction())
112109

113110
dispatch_actions(store)
114111

115112

116-
def test_autorun_autoawait(
117-
store: StoreType,
118-
event_loop: LoopThread,
119-
) -> None:
113+
def test_autorun_autoawait(store: StoreType) -> None:
120114
@store.autorun(lambda state: state.value, options=AutorunOptions(auto_await=False))
121115
async def sync_mirror(value: int) -> int:
122116
store.dispatch(SetMirroredValueAction(value=value))
@@ -139,14 +133,10 @@ async def _(values: tuple[int, int]) -> None:
139133
elif value < INCREMENTS:
140134
store.dispatch(IncrementAction())
141135
else:
142-
event_loop.stop()
143136
store.dispatch(FinishAction())
144137

145138

146-
def test_autorun_default_value(
147-
store: StoreType,
148-
event_loop: LoopThread,
149-
) -> None:
139+
def test_autorun_default_value(store: StoreType) -> None:
150140
@store.autorun(lambda state: state.value, options=AutorunOptions(default_value=5))
151141
async def _(value: int) -> int:
152142
store.dispatch(SetMirroredValueAction(value=value))
@@ -156,19 +146,16 @@ async def _(value: int) -> int:
156146
lambda state: state.mirrored_value,
157147
lambda state: state.mirrored_value >= INCREMENTS,
158148
)
159-
def _(mirrored_value: int) -> None:
149+
async def _(mirrored_value: int) -> None:
160150
if mirrored_value < INCREMENTS:
161151
return
162-
event_loop.stop()
152+
await asyncio.sleep(0.1)
163153
store.dispatch(FinishAction())
164154

165155
dispatch_actions(store)
166156

167157

168-
def test_view(
169-
store: StoreType,
170-
event_loop: LoopThread,
171-
) -> None:
158+
def test_view(store: StoreType) -> None:
172159
calls = []
173160

174161
@store.view(lambda state: state.value)
@@ -184,12 +171,11 @@ async def _(value: int) -> None:
184171
if value < INCREMENTS:
185172
store.dispatch(IncrementAction())
186173
else:
187-
event_loop.stop()
188174
store.dispatch(FinishAction())
189175
assert calls == list(range(INCREMENTS + 1))
190176

191177

192-
def test_view_await(store: StoreType, event_loop: LoopThread) -> None:
178+
def test_view_await(store: StoreType) -> None:
193179
calls = []
194180

195181
@store.view(lambda state: state.value)
@@ -208,15 +194,11 @@ async def _(value: int) -> None:
208194
if value < INCREMENTS:
209195
store.dispatch(IncrementAction())
210196
else:
211-
event_loop.stop()
212197
store.dispatch(FinishAction())
213198
assert calls == list(range(INCREMENTS + 1))
214199

215200

216-
def test_view_with_args(
217-
store: StoreType,
218-
event_loop: LoopThread,
219-
) -> None:
201+
def test_view_with_args(store: StoreType) -> None:
220202
calls = []
221203

222204
@store.view(lambda state: state.value)
@@ -231,15 +213,11 @@ async def _(value: int) -> None:
231213
if value < INCREMENTS:
232214
store.dispatch(IncrementAction())
233215
else:
234-
event_loop.stop()
235216
store.dispatch(FinishAction())
236217
assert calls == [j for i in list(range(INCREMENTS + 1)) for j in [i] * 2]
237218

238219

239-
def test_view_with_default_value(
240-
store: StoreType,
241-
event_loop: LoopThread,
242-
) -> None:
220+
def test_view_with_default_value(store: StoreType) -> None:
243221
calls = []
244222

245223
@store.view(lambda state: state.value, options=ViewOptions(default_value=5))
@@ -253,51 +231,39 @@ async def _(value: int) -> None:
253231
if value < INCREMENTS:
254232
store.dispatch(IncrementAction())
255233
else:
256-
event_loop.stop()
257234
store.dispatch(FinishAction())
258235
assert calls == list(range(INCREMENTS + 1))
259236

260237
store.dispatch(InitAction())
261238

262239

263-
def test_subscription(
264-
store: StoreType,
265-
event_loop: LoopThread,
266-
) -> None:
240+
def test_subscription(store: StoreType) -> None:
267241
async def render(state: StateType) -> None:
242+
await asyncio.sleep(0.1)
268243
if state.value == INCREMENTS:
269244
unsubscribe()
270245
store.dispatch(FinishAction())
271-
event_loop.stop()
272246

273247
unsubscribe = store.subscribe(render)
274248

275249
dispatch_actions(store)
276250

277251

278-
def test_event_subscription(
279-
store: StoreType,
280-
event_loop: LoopThread,
281-
) -> None:
282-
async def finish() -> None:
252+
def test_event_subscription(store: StoreType) -> None:
253+
async def handler(event: IncrementEvent) -> None:
283254
await asyncio.sleep(0.1)
284-
event_loop.stop()
255+
if event.post_value == INCREMENTS:
256+
unsubscribe()
257+
store.dispatch(FinishAction())
285258

286-
store.subscribe_event(FinishEvent, finish)
287-
store.dispatch(FinishAction())
259+
unsubscribe = store.subscribe_event(IncrementEvent, handler)
288260

289261
dispatch_actions(store)
290262

291263

292-
def test_event_subscription_with_no_task_creator(event_loop: LoopThread) -> None:
264+
def test_event_subscription_with_no_task_creator() -> None:
293265
store = Store(
294266
reducer,
295267
options=CreateStoreOptions(auto_init=True),
296268
)
297-
298-
async def finish() -> None:
299-
await asyncio.sleep(0.1)
300-
event_loop.stop()
301-
302-
store.subscribe_event(FinishEvent, finish)
303269
store.dispatch(FinishAction())

0 commit comments

Comments
 (0)