2626from reflex .base import Base
2727from reflex .constants import CompileVars , RouteVar , SocketEvent
2828from reflex .constants .state import FIELD_MARKER
29+ from reflex .environment import environment
2930from reflex .event import Event , EventHandler
3031from reflex .istate .manager import StateManager
3132from reflex .istate .manager .disk import StateManagerDisk
@@ -1740,6 +1741,10 @@ async def test_state_manager_modify_state(
17401741 complex_1 = state .complex [1 ]
17411742 assert isinstance (complex_1 , MutableProxy )
17421743 state .complex [3 ] = complex_1
1744+
1745+ if environment .REFLEX_OPLOCK_ENABLED .get ():
1746+ await state_manager .close ()
1747+
17431748 # lock should be dropped after exiting the context
17441749 if isinstance (state_manager , StateManagerRedis ):
17451750 assert (await state_manager .redis .get (f"{ token } _lock" )) is None
@@ -1783,6 +1788,9 @@ async def _coro():
17831788 for f in asyncio .as_completed (tasks ):
17841789 await f
17851790
1791+ if environment .REFLEX_OPLOCK_ENABLED .get ():
1792+ await state_manager .close ()
1793+
17861794 assert (await state_manager .get_state (substate_token )).num1 == exp_num1
17871795
17881796 if isinstance (state_manager , StateManagerRedis ):
@@ -1837,12 +1845,35 @@ async def test_state_manager_lock_expire(
18371845 state_manager_redis .lock_expiration = LOCK_EXPIRATION
18381846 state_manager_redis .lock_warning_threshold = LOCK_WARNING_THRESHOLD
18391847
1848+ loop_exception = None
1849+
1850+ def loop_exception_handler (loop , context ):
1851+ """Catch the LockExpiredError from the event loop.
1852+
1853+ Args:
1854+ loop: The event loop.
1855+ context: The exception context.
1856+ """
1857+ nonlocal loop_exception
1858+ loop_exception = context ["exception" ]
1859+
1860+ asyncio .get_event_loop ().set_exception_handler (loop_exception_handler )
1861+
18401862 async with state_manager_redis .modify_state (substate_token_redis ):
18411863 await asyncio .sleep (0.01 )
18421864
1843- with pytest . raises ( LockExpiredError ):
1865+ if environment . REFLEX_OPLOCK_ENABLED . get ( ):
18441866 async with state_manager_redis .modify_state (substate_token_redis ):
18451867 await asyncio .sleep (LOCK_EXPIRE_SLEEP )
1868+ await asyncio .sleep (LOCK_EXPIRE_SLEEP )
1869+ assert loop_exception is not None
1870+ with pytest .raises (LockExpiredError ):
1871+ raise loop_exception
1872+ else :
1873+ with pytest .raises (LockExpiredError ):
1874+ async with state_manager_redis .modify_state (substate_token_redis ):
1875+ await asyncio .sleep (LOCK_EXPIRE_SLEEP )
1876+ assert loop_exception is None
18461877
18471878
18481879@pytest .mark .asyncio
@@ -1862,6 +1893,20 @@ async def test_state_manager_lock_expire_contend(
18621893 state_manager_redis .lock_expiration = LOCK_EXPIRATION
18631894 state_manager_redis .lock_warning_threshold = LOCK_WARNING_THRESHOLD
18641895
1896+ loop_exception = None
1897+
1898+ def loop_exception_handler (loop , context ):
1899+ """Catch the LockExpiredError from the event loop.
1900+
1901+ Args:
1902+ loop: The event loop.
1903+ context: The exception context.
1904+ """
1905+ nonlocal loop_exception
1906+ loop_exception = context ["exception" ]
1907+
1908+ asyncio .get_event_loop ().set_exception_handler (loop_exception_handler )
1909+
18651910 order = []
18661911 waiter_event = asyncio .Event ()
18671912
@@ -1876,19 +1921,31 @@ async def _coro_waiter():
18761921 await waiter_event .wait ()
18771922 async with state_manager_redis .modify_state (substate_token_redis ) as state :
18781923 order .append ("waiter" )
1879- assert state .num1 != unexp_num1
18801924 state .num1 = exp_num1
18811925
18821926 tasks = [
18831927 asyncio .create_task (_coro_blocker ()),
18841928 asyncio .create_task (_coro_waiter ()),
18851929 ]
1886- with pytest .raises (LockExpiredError ):
1887- await tasks [0 ]
1888- await tasks [1 ]
1930+ if environment .REFLEX_OPLOCK_ENABLED .get ():
1931+ await tasks [0 ] # Doesn't raise during `modify_state`, only on exit
1932+ await tasks [1 ]
1933+ await asyncio .sleep (LOCK_EXPIRE_SLEEP )
1934+ assert loop_exception is not None
1935+ with pytest .raises (LockExpiredError ):
1936+ raise loop_exception
1937+ # In oplock mode, the blocker block's both updates
1938+ assert (await state_manager_redis .get_state (substate_token_redis )).num1 == 0
1939+ else :
1940+ with pytest .raises (LockExpiredError ):
1941+ await tasks [0 ]
1942+ await tasks [1 ]
1943+ assert loop_exception is None
1944+ assert (
1945+ await state_manager_redis .get_state (substate_token_redis )
1946+ ).num1 == exp_num1
18891947
18901948 assert order == ["blocker" , "waiter" ]
1891- assert (await state_manager_redis .get_state (substate_token_redis )).num1 == exp_num1
18921949
18931950
18941951@pytest .mark .asyncio
@@ -1923,8 +1980,12 @@ async def _coro_blocker():
19231980 ]
19241981
19251982 await tasks [0 ]
1926- console_warn .assert_called ()
1927- assert console_warn .call_count == 7
1983+ if environment .REFLEX_OPLOCK_ENABLED .get ():
1984+ # When Oplock is enabled, we don't warn when lock is held too long.
1985+ console_warn .assert_not_called ()
1986+ else :
1987+ console_warn .assert_called ()
1988+ assert console_warn .call_count == 7
19281989
19291990
19301991class CopyingAsyncMock (AsyncMock ):
@@ -2079,6 +2140,9 @@ async def test_state_proxy(
20792140 assert sp ._self_actx is None
20802141 assert sp .value2 == "42"
20812142
2143+ if environment .REFLEX_OPLOCK_ENABLED .get ():
2144+ await mock_app .state_manager .close ()
2145+
20822146 # Get the state from the state manager directly and check that the value is updated
20832147 gotten_state = await mock_app .state_manager .get_state (
20842148 _substate_key (grandchild_state .router .session .client_token , grandchild_state )
@@ -2289,6 +2353,9 @@ async def test_background_task_no_block(mock_app: rx.App, token: str):
22892353 await task
22902354 assert not mock_app ._background_tasks
22912355
2356+ if environment .REFLEX_OPLOCK_ENABLED .get ():
2357+ await mock_app .state_manager .close ()
2358+
22922359 exp_order = [
22932360 "background_task:start" ,
22942361 "other" ,
@@ -2385,6 +2452,9 @@ async def test_background_task_reset(mock_app: rx.App, token: str):
23852452 await task
23862453 assert not mock_app ._background_tasks
23872454
2455+ if environment .REFLEX_OPLOCK_ENABLED .get ():
2456+ await mock_app .state_manager .close ()
2457+
23882458 assert (
23892459 await mock_app .state_manager .get_state (
23902460 _substate_key (token , BackgroundTaskState )
0 commit comments