Skip to content

Commit e20cb81

Browse files
committed
Fix up unit tests for OPLOCK_ENABLED mode
1 parent e1ef249 commit e20cb81

File tree

4 files changed

+121
-20
lines changed

4 files changed

+121
-20
lines changed

.github/workflows/unit_tests.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,13 @@ jobs:
6262
export PYTHONUNBUFFERED=1
6363
export REFLEX_REDIS_URL=redis://localhost:6379
6464
uv run pytest tests/units --cov --no-cov-on-fail --cov-report=
65+
- name: Run unit tests w/ redis and OPLOCK_ENABLED
66+
if: ${{ matrix.os == 'ubuntu-latest' }}
67+
run: |
68+
export PYTHONUNBUFFERED=1
69+
export REFLEX_REDIS_URL=redis://localhost:6379
70+
export REFLEX_OPLOCK_ENABLED=true
71+
uv run pytest tests/units --cov --no-cov-on-fail --cov-report=
6572
# Change to explicitly install v1 when reflex-hosting-cli is compatible with v2
6673
- name: Run unit tests w/ pydantic v1
6774
run: |

reflex/istate/manager/redis.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -603,19 +603,21 @@ async def do_flush() -> None:
603603
async with state_lock:
604604
# Write the state to redis while no one else can modify the cached copy.
605605
state = self._cached_states.pop(client_token, None)
606-
if state:
607-
if self._debug_enabled:
606+
try:
607+
if state:
608+
if self._debug_enabled:
609+
console.debug(
610+
f"{SMR} [{time.monotonic() - start:.3f}] {client_token} lease breaker {lock_id.decode()} flushing state"
611+
)
612+
await self.set_state(token, state, lock_id=lock_id, **context)
613+
finally:
614+
if (current_lease := self._local_leases.get(client_token)) is task:
615+
self._local_leases.pop(client_token, None)
616+
# TODO: clean up the cached states locks periodically
617+
elif self._debug_enabled:
608618
console.debug(
609-
f"{SMR} [{time.monotonic() - start:.3f}] {client_token} lease breaker {lock_id.decode()} flushing state"
619+
f"{SMR} [{time.monotonic() - start:.3f}] {client_token} lease breaker {lock_id.decode()} cleanup of {task=} found different task in _local_leases {current_lease=}."
610620
)
611-
await self.set_state(token, state, lock_id=lock_id, **context)
612-
if (current_lease := self._local_leases.get(client_token)) is task:
613-
self._local_leases.pop(client_token, None)
614-
# TODO: clean up the cached states locks periodically
615-
elif self._debug_enabled:
616-
console.debug(
617-
f"{SMR} [{time.monotonic() - start:.3f}] {client_token} lease breaker {lock_id.decode()} cleanup of {task=} found different task in _local_leases {current_lease=}."
618-
)
619621

620622
async def lease_breaker():
621623
cancelled_error: asyncio.CancelledError | None = None
@@ -638,6 +640,14 @@ async def lease_breaker():
638640
try:
639641
# Shield the flush from cancellation to ensure it always runs to completion.
640642
await asyncio.shield(do_flush())
643+
except Exception as e:
644+
# Propagate exception to the main loop, since we have nowhere to catch it.
645+
if not isinstance(e, asyncio.CancelledError):
646+
asyncio.get_running_loop().call_exception_handler({
647+
"message": "Exception in Redis State Manager lease breaker",
648+
"exception": e,
649+
})
650+
raise
641651
finally:
642652
# Re-raise any cancellation error after cleaning up.
643653
if cancelled_error is not None:
@@ -993,7 +1003,7 @@ async def _lock(self, token: str):
9931003
console.debug(
9941004
f"{SMR} [{time.monotonic() - start:.3f}] {lock_key.decode()} released by {lock_id.decode()}"
9951005
)
996-
else:
1006+
elif deleted_lock_id is not None:
9971007
# This can happen if the caller never tried to `set_state` before the lock expired and is a pretty bad bug.
9981008
console.warn(
9991009
f"{lock_key.decode()} was released by {lock_id.decode()}, but it belonged to {deleted_lock_id.decode()}. This is a bug."
@@ -1020,5 +1030,6 @@ async def close(self):
10201030
# Then cancel all outstanding leases and write the cached states to redis.
10211031
for lease_task in self._local_leases.values():
10221032
lease_task.cancel()
1033+
await asyncio.gather(*self._local_leases.values(), return_exceptions=True)
10231034
finally:
10241035
await self.redis.aclose(close_connection_pool=True)

tests/units/test_app.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from reflex.components.core.cond import Cond
3333
from reflex.components.radix.themes.typography.text import Text
3434
from reflex.constants.state import FIELD_MARKER
35+
from reflex.environment import environment
3536
from reflex.event import Event
3637
from reflex.istate.manager.disk import StateManagerDisk
3738
from reflex.istate.manager.memory import StateManagerMemory
@@ -990,6 +991,9 @@ def getlist(key: str):
990991
== StateUpdate(delta=delta, events=[], final=True).json() + "\n"
991992
)
992993

994+
if environment.REFLEX_OPLOCK_ENABLED.get():
995+
await app.state_manager.close()
996+
993997
current_state = await app.state_manager.get_state(_substate_key(token, state))
994998
state_dict = current_state.dict()[state.get_full_name()]
995999
assert state_dict["img_list" + FIELD_MARKER] == [
@@ -1296,6 +1300,9 @@ def _dynamic_state_event(name, val, **kwargs):
12961300
with pytest.raises(StopAsyncIteration):
12971301
await process_coro.__anext__()
12981302

1303+
if environment.REFLEX_OPLOCK_ENABLED.get():
1304+
await app.state_manager.close()
1305+
12991306
# check that router data was written to the state_manager store
13001307
state = await app.state_manager.get_state(substate_token)
13011308
assert state.dynamic == exp_val
@@ -1363,6 +1370,9 @@ def _dynamic_state_event(name, val, **kwargs):
13631370
await process_coro.__anext__()
13641371

13651372
prev_exp_val = exp_val
1373+
1374+
if environment.REFLEX_OPLOCK_ENABLED.get():
1375+
await app.state_manager.close()
13661376
state = await app.state_manager.get_state(substate_token)
13671377
assert state.loaded == len(exp_vals)
13681378
assert state.counter == len(exp_vals)
@@ -1403,6 +1413,9 @@ async def test_process_events(mocker: MockerFixture, token: str):
14031413
async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"):
14041414
pass
14051415

1416+
if environment.REFLEX_OPLOCK_ENABLED.get():
1417+
await app.state_manager.close()
1418+
14061419
assert (await app.state_manager.get_state(event.substate_token)).value == 5
14071420
assert app._postprocess.call_count == 6 # pyright: ignore [reportAttributeAccessIssue]
14081421

tests/units/test_state.py

Lines changed: 78 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from reflex.base import Base
2727
from reflex.constants import CompileVars, RouteVar, SocketEvent
2828
from reflex.constants.state import FIELD_MARKER
29+
from reflex.environment import environment
2930
from reflex.event import Event, EventHandler
3031
from reflex.istate.manager import StateManager
3132
from reflex.istate.manager.disk import StateManagerDisk
@@ -1740,6 +1741,10 @@ async def test_state_manager_modify_state(
17401741
complex_1 = state.complex[1]
17411742
assert isinstance(complex_1, MutableProxy)
17421743
state.complex[3] = complex_1
1744+
1745+
if environment.REFLEX_OPLOCK_ENABLED.get():
1746+
await state_manager.close()
1747+
17431748
# lock should be dropped after exiting the context
17441749
if isinstance(state_manager, StateManagerRedis):
17451750
assert (await state_manager.redis.get(f"{token}_lock")) is None
@@ -1783,6 +1788,9 @@ async def _coro():
17831788
for f in asyncio.as_completed(tasks):
17841789
await f
17851790

1791+
if environment.REFLEX_OPLOCK_ENABLED.get():
1792+
await state_manager.close()
1793+
17861794
assert (await state_manager.get_state(substate_token)).num1 == exp_num1
17871795

17881796
if isinstance(state_manager, StateManagerRedis):
@@ -1837,12 +1845,35 @@ async def test_state_manager_lock_expire(
18371845
state_manager_redis.lock_expiration = LOCK_EXPIRATION
18381846
state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
18391847

1848+
loop_exception = None
1849+
1850+
def loop_exception_handler(loop, context):
1851+
"""Catch the LockExpiredError from the event loop.
1852+
1853+
Args:
1854+
loop: The event loop.
1855+
context: The exception context.
1856+
"""
1857+
nonlocal loop_exception
1858+
loop_exception = context["exception"]
1859+
1860+
asyncio.get_event_loop().set_exception_handler(loop_exception_handler)
1861+
18401862
async with state_manager_redis.modify_state(substate_token_redis):
18411863
await asyncio.sleep(0.01)
18421864

1843-
with pytest.raises(LockExpiredError):
1865+
if environment.REFLEX_OPLOCK_ENABLED.get():
18441866
async with state_manager_redis.modify_state(substate_token_redis):
18451867
await asyncio.sleep(LOCK_EXPIRE_SLEEP)
1868+
await asyncio.sleep(LOCK_EXPIRE_SLEEP)
1869+
assert loop_exception is not None
1870+
with pytest.raises(LockExpiredError):
1871+
raise loop_exception
1872+
else:
1873+
with pytest.raises(LockExpiredError):
1874+
async with state_manager_redis.modify_state(substate_token_redis):
1875+
await asyncio.sleep(LOCK_EXPIRE_SLEEP)
1876+
assert loop_exception is None
18461877

18471878

18481879
@pytest.mark.asyncio
@@ -1862,6 +1893,20 @@ async def test_state_manager_lock_expire_contend(
18621893
state_manager_redis.lock_expiration = LOCK_EXPIRATION
18631894
state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD
18641895

1896+
loop_exception = None
1897+
1898+
def loop_exception_handler(loop, context):
1899+
"""Catch the LockExpiredError from the event loop.
1900+
1901+
Args:
1902+
loop: The event loop.
1903+
context: The exception context.
1904+
"""
1905+
nonlocal loop_exception
1906+
loop_exception = context["exception"]
1907+
1908+
asyncio.get_event_loop().set_exception_handler(loop_exception_handler)
1909+
18651910
order = []
18661911
waiter_event = asyncio.Event()
18671912

@@ -1876,19 +1921,31 @@ async def _coro_waiter():
18761921
await waiter_event.wait()
18771922
async with state_manager_redis.modify_state(substate_token_redis) as state:
18781923
order.append("waiter")
1879-
assert state.num1 != unexp_num1
18801924
state.num1 = exp_num1
18811925

18821926
tasks = [
18831927
asyncio.create_task(_coro_blocker()),
18841928
asyncio.create_task(_coro_waiter()),
18851929
]
1886-
with pytest.raises(LockExpiredError):
1887-
await tasks[0]
1888-
await tasks[1]
1930+
if environment.REFLEX_OPLOCK_ENABLED.get():
1931+
await tasks[0] # Doesn't raise during `modify_state`, only on exit
1932+
await tasks[1]
1933+
await asyncio.sleep(LOCK_EXPIRE_SLEEP)
1934+
assert loop_exception is not None
1935+
with pytest.raises(LockExpiredError):
1936+
raise loop_exception
1937+
# In oplock mode, the blocker block's both updates
1938+
assert (await state_manager_redis.get_state(substate_token_redis)).num1 == 0
1939+
else:
1940+
with pytest.raises(LockExpiredError):
1941+
await tasks[0]
1942+
await tasks[1]
1943+
assert loop_exception is None
1944+
assert (
1945+
await state_manager_redis.get_state(substate_token_redis)
1946+
).num1 == exp_num1
18891947

18901948
assert order == ["blocker", "waiter"]
1891-
assert (await state_manager_redis.get_state(substate_token_redis)).num1 == exp_num1
18921949

18931950

18941951
@pytest.mark.asyncio
@@ -1923,8 +1980,12 @@ async def _coro_blocker():
19231980
]
19241981

19251982
await tasks[0]
1926-
console_warn.assert_called()
1927-
assert console_warn.call_count == 7
1983+
if environment.REFLEX_OPLOCK_ENABLED.get():
1984+
# When Oplock is enabled, we don't warn when lock is held too long.
1985+
console_warn.assert_not_called()
1986+
else:
1987+
console_warn.assert_called()
1988+
assert console_warn.call_count == 7
19281989

19291990

19301991
class CopyingAsyncMock(AsyncMock):
@@ -2079,6 +2140,9 @@ async def test_state_proxy(
20792140
assert sp._self_actx is None
20802141
assert sp.value2 == "42"
20812142

2143+
if environment.REFLEX_OPLOCK_ENABLED.get():
2144+
await mock_app.state_manager.close()
2145+
20822146
# Get the state from the state manager directly and check that the value is updated
20832147
gotten_state = await mock_app.state_manager.get_state(
20842148
_substate_key(grandchild_state.router.session.client_token, grandchild_state)
@@ -2289,6 +2353,9 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
22892353
await task
22902354
assert not mock_app._background_tasks
22912355

2356+
if environment.REFLEX_OPLOCK_ENABLED.get():
2357+
await mock_app.state_manager.close()
2358+
22922359
exp_order = [
22932360
"background_task:start",
22942361
"other",
@@ -2385,6 +2452,9 @@ async def test_background_task_reset(mock_app: rx.App, token: str):
23852452
await task
23862453
assert not mock_app._background_tasks
23872454

2455+
if environment.REFLEX_OPLOCK_ENABLED.get():
2456+
await mock_app.state_manager.close()
2457+
23882458
assert (
23892459
await mock_app.state_manager.get_state(
23902460
_substate_key(token, BackgroundTaskState)

0 commit comments

Comments
 (0)