Skip to content

Commit 6aa87be

Browse files
Bubble state out of heartbeat
1 parent 11fcf17 commit 6aa87be

File tree

1 file changed

+63
-22
lines changed

1 file changed

+63
-22
lines changed

src/replit_river/session.py

Lines changed: 63 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import enum
33
import logging
4-
from typing import Any, Callable, Coroutine
4+
from typing import Any, Awaitable, Callable, Coroutine, Protocol
55

66
import nanoid # type: ignore
77
import websockets
@@ -42,6 +42,19 @@
4242
trace_setter = TransportMessageTracingSetter()
4343

4444

45+
class SendMessage(Protocol):
46+
async def __call__(
47+
self,
48+
*,
49+
stream_id: str,
50+
payload: dict[Any, Any] | str,
51+
control_flags: int,
52+
service_name: str | None,
53+
procedure_name: str | None,
54+
span: Span | None,
55+
) -> None: ...
56+
57+
4558
class SessionState(enum.Enum):
4659
"""The state a session can be in.
4760
@@ -53,7 +66,7 @@ class SessionState(enum.Enum):
5366
CLOSED = 2
5467

5568

56-
class Session(object):
69+
class Session:
5770
"""A transport object that handles the websocket connection with a client."""
5871

5972
def __init__(
@@ -106,7 +119,29 @@ def __init__(
106119
self._setup_heartbeats_task()
107120

108121
def _setup_heartbeats_task(self) -> None:
109-
self._task_manager.create_task(self._heartbeat())
122+
async def do_close_websocket() -> None:
123+
await self.close_websocket(
124+
self._ws_wrapper,
125+
should_retry=not self._is_server,
126+
)
127+
await self._begin_close_session_countdown()
128+
129+
def increment_and_get_heartbeat_misses() -> int:
130+
self._heartbeat_misses += 1
131+
return self._heartbeat_misses
132+
133+
self._task_manager.create_task(
134+
self._heartbeat(
135+
self.session_id,
136+
self._transport_options.heartbeat_ms,
137+
self._transport_options.heartbeats_until_dead,
138+
lambda: self._state,
139+
lambda: self._close_session_after_time_secs,
140+
close_websocket=do_close_websocket,
141+
send_message=self.send_message,
142+
increment_and_get_heartbeat_misses=increment_and_get_heartbeat_misses,
143+
)
144+
)
110145
self._task_manager.create_task(self._check_to_close_session())
111146

112147
async def is_session_open(self) -> bool:
@@ -276,45 +311,51 @@ async def _check_to_close_session(self) -> None:
276311

277312
async def _heartbeat(
278313
self,
314+
session_id: str,
315+
heartbeat_ms: float,
316+
heartbeats_until_dead: int,
317+
get_state: Callable[[], SessionState],
318+
get_closing_grace_period: Callable[[], float | None],
319+
close_websocket: Callable[[], Awaitable[None]],
320+
send_message: SendMessage,
321+
increment_and_get_heartbeat_misses: Callable[[], int],
279322
) -> None:
280323
logger.debug("Start heartbeat")
281324
while True:
282-
await asyncio.sleep(self._transport_options.heartbeat_ms / 1000)
283-
if self._state != SessionState.ACTIVE:
325+
await asyncio.sleep(heartbeat_ms / 1000)
326+
state = get_state()
327+
if state != SessionState.ACTIVE:
284328
logger.debug(
285329
"Session is closed, no need to send heartbeat, state : "
286330
"%r close_session_after_this: %r",
287-
{self._state},
288-
{self._close_session_after_time_secs},
331+
{state},
332+
{get_closing_grace_period()},
289333
)
290334
# session is closing / closed, no need to send heartbeat anymore
291335
return
292336
try:
293-
await self.send_message(
294-
"heartbeat",
337+
await send_message(
338+
stream_id="heartbeat",
295339
# TODO: make this a message class
296340
# https://github.com/replit/river/blob/741b1ea6d7600937ad53564e9cf8cd27a92ec36a/transport/message.ts#L42
297-
{
341+
payload={
298342
"ack": 0,
299343
},
300-
ACK_BIT,
344+
control_flags=ACK_BIT,
345+
procedure_name=None,
346+
service_name=None,
347+
span=None,
301348
)
302-
self._heartbeat_misses += 1
303-
if (
304-
self._heartbeat_misses
305-
> self._transport_options.heartbeats_until_dead
306-
):
307-
if self._close_session_after_time_secs is not None:
349+
350+
if increment_and_get_heartbeat_misses() > heartbeats_until_dead:
351+
if get_closing_grace_period() is not None:
308352
# already in grace period, no need to set again
309353
continue
310354
logger.info(
311355
"%r closing websocket because of heartbeat misses",
312-
self.session_id,
356+
session_id,
313357
)
314-
await self.close_websocket(
315-
self._ws_wrapper, should_retry=not self._is_server
316-
)
317-
await self._begin_close_session_countdown()
358+
await close_websocket()
318359
continue
319360
except FailedSendingMessageException:
320361
# this is expected during websocket closed period

0 commit comments

Comments
 (0)