22
33import asyncio
44import json
5+ import time
56from collections .abc import Callable , Generator
67from contextlib import asynccontextmanager
78from unittest .mock import AsyncMock , Mock , patch
@@ -604,6 +605,18 @@ async def test_redis_token_manager_get_token_owner(
604605 assert await manager2 ._get_token_owner ("token2" ) == manager2 .instance_id
605606
606607
608+ async def _wait_for_call_count_positive (mock : Mock , timeout : float = 5.0 ):
609+ """Wait until the mock's call count is positive.
610+
611+ Args:
612+ mock: The mock to wait on.
613+ timeout: The maximum time to wait in seconds.
614+ """
615+ deadline = time .monotonic () + timeout
616+ while mock .call_count == 0 and time .monotonic () < deadline : # noqa: ASYNC110
617+ await asyncio .sleep (0.1 )
618+
619+
607620@pytest .mark .usefixtures ("redis_url" )
608621@pytest .mark .asyncio
609622async def test_redis_token_manager_lost_and_found (
@@ -623,11 +636,13 @@ async def test_redis_token_manager_lost_and_found(
623636 await event_namespace2 .on_connect (sid = "sid2" , environ = query_string_for ("token2" ))
624637
625638 await event_namespace2 .emit_update (StateUpdate (), token = "token1" )
639+ await _wait_for_call_count_positive (emit1_mock )
626640 emit2_mock .assert_not_called ()
627641 emit1_mock .assert_called_once ()
628642 emit1_mock .reset_mock ()
629643
630644 await event_namespace2 .emit_update (StateUpdate (), token = "token2" )
645+ await _wait_for_call_count_positive (emit2_mock )
631646 emit1_mock .assert_not_called ()
632647 emit2_mock .assert_called_once ()
633648 emit2_mock .reset_mock ()
@@ -636,11 +651,13 @@ async def test_redis_token_manager_lost_and_found(
636651 await task
637652 await event_namespace2 .emit_update (StateUpdate (), token = "token1" )
638653 # Update should be dropped on the floor.
654+ await asyncio .sleep (2 )
639655 emit1_mock .assert_not_called ()
640656 emit2_mock .assert_not_called ()
641657
642658 await event_namespace2 .on_connect (sid = "sid1" , environ = query_string_for ("token1" ))
643659 await event_namespace2 .emit_update (StateUpdate (), token = "token1" )
660+ await _wait_for_call_count_positive (emit2_mock )
644661 emit1_mock .assert_not_called ()
645662 emit2_mock .assert_called_once ()
646663 emit2_mock .reset_mock ()
@@ -649,6 +666,7 @@ async def test_redis_token_manager_lost_and_found(
649666 await task
650667 await event_namespace1 .on_connect (sid = "sid1" , environ = query_string_for ("token1" ))
651668 await event_namespace2 .emit_update (StateUpdate (), token = "token1" )
669+ await _wait_for_call_count_positive (emit1_mock )
652670 emit2_mock .assert_not_called ()
653671 emit1_mock .assert_called_once ()
654672 emit1_mock .reset_mock ()
0 commit comments