Skip to content

Commit e27368f

Browse files
Break out heartbeat lifecycle
1 parent 6aa87be commit e27368f

File tree

2 files changed

+90
-79
lines changed

2 files changed

+90
-79
lines changed

src/replit_river/common_session.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import asyncio
2+
import enum
3+
import logging
4+
from typing import Any, Awaitable, Callable, Protocol
5+
6+
from opentelemetry.trace import Span
7+
8+
from replit_river.messages import FailedSendingMessageException
9+
from replit_river.rpc import ACK_BIT
10+
11+
12+
logger = logging.getLogger(__name__)
13+
14+
class SendMessage(Protocol):
15+
async def __call__(
16+
self,
17+
*,
18+
stream_id: str,
19+
payload: dict[Any, Any] | str,
20+
control_flags: int,
21+
service_name: str | None,
22+
procedure_name: str | None,
23+
span: Span | None,
24+
) -> None: ...
25+
26+
27+
class SessionState(enum.Enum):
28+
"""The state a session can be in.
29+
30+
Can only transition from ACTIVE to CLOSING to CLOSED.
31+
"""
32+
33+
ACTIVE = 0
34+
CLOSING = 1
35+
CLOSED = 2
36+
37+
38+
async def setup_heartbeat(
39+
session_id: str,
40+
heartbeat_ms: float,
41+
heartbeats_until_dead: int,
42+
get_state: Callable[[], SessionState],
43+
get_closing_grace_period: Callable[[], float | None],
44+
close_websocket: Callable[[], Awaitable[None]],
45+
send_message: SendMessage,
46+
increment_and_get_heartbeat_misses: Callable[[], int],
47+
) -> None:
48+
logger.debug("Start heartbeat")
49+
while True:
50+
await asyncio.sleep(heartbeat_ms / 1000)
51+
state = get_state()
52+
if state != SessionState.ACTIVE:
53+
logger.debug(
54+
"Session is closed, no need to send heartbeat, state : "
55+
"%r close_session_after_this: %r",
56+
{state},
57+
{get_closing_grace_period()},
58+
)
59+
# session is closing / closed, no need to send heartbeat anymore
60+
return
61+
try:
62+
await send_message(
63+
stream_id="heartbeat",
64+
# TODO: make this a message class
65+
# https://github.com/replit/river/blob/741b1ea6d7600937ad53564e9cf8cd27a92ec36a/transport/message.ts#L42
66+
payload={
67+
"ack": 0,
68+
},
69+
control_flags=ACK_BIT,
70+
procedure_name=None,
71+
service_name=None,
72+
span=None,
73+
)
74+
75+
if increment_and_get_heartbeat_misses() > heartbeats_until_dead:
76+
if get_closing_grace_period() is not None:
77+
# already in grace period, no need to set again
78+
continue
79+
logger.info(
80+
"%r closing websocket because of heartbeat misses",
81+
session_id,
82+
)
83+
await close_websocket()
84+
continue
85+
except FailedSendingMessageException:
86+
# this is expected during websocket closed period
87+
continue

src/replit_river/session.py

Lines changed: 3 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import asyncio
2-
import enum
32
import logging
4-
from typing import Any, Awaitable, Callable, Coroutine, Protocol
3+
from typing import Any, Callable, Coroutine
54

65
import nanoid # type: ignore
76
import websockets
@@ -10,6 +9,7 @@
109
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
1110
from websockets.exceptions import ConnectionClosed
1211

12+
from replit_river.common_session import SessionState, setup_heartbeat
1313
from replit_river.message_buffer import MessageBuffer, MessageBufferClosedError
1414
from replit_river.messages import (
1515
FailedSendingMessageException,
@@ -42,30 +42,6 @@
4242
trace_setter = TransportMessageTracingSetter()
4343

4444

45-
class SendMessage(Protocol):
46-
async def __call__(
47-
self,
48-
*,
49-
stream_id: str,
50-
payload: dict[Any, Any] | str,
51-
control_flags: int,
52-
service_name: str | None,
53-
procedure_name: str | None,
54-
span: Span | None,
55-
) -> None: ...
56-
57-
58-
class SessionState(enum.Enum):
59-
"""The state a session can be in.
60-
61-
Can only transition from ACTIVE to CLOSING to CLOSED.
62-
"""
63-
64-
ACTIVE = 0
65-
CLOSING = 1
66-
CLOSED = 2
67-
68-
6945
class Session:
7046
"""A transport object that handles the websocket connection with a client."""
7147

@@ -131,7 +107,7 @@ def increment_and_get_heartbeat_misses() -> int:
131107
return self._heartbeat_misses
132108

133109
self._task_manager.create_task(
134-
self._heartbeat(
110+
setup_heartbeat(
135111
self.session_id,
136112
self._transport_options.heartbeat_ms,
137113
self._transport_options.heartbeats_until_dead,
@@ -309,58 +285,6 @@ async def _check_to_close_session(self) -> None:
309285
await self.close()
310286
return
311287

312-
async def _heartbeat(
313-
self,
314-
session_id: str,
315-
heartbeat_ms: float,
316-
heartbeats_until_dead: int,
317-
get_state: Callable[[], SessionState],
318-
get_closing_grace_period: Callable[[], float | None],
319-
close_websocket: Callable[[], Awaitable[None]],
320-
send_message: SendMessage,
321-
increment_and_get_heartbeat_misses: Callable[[], int],
322-
) -> None:
323-
logger.debug("Start heartbeat")
324-
while True:
325-
await asyncio.sleep(heartbeat_ms / 1000)
326-
state = get_state()
327-
if state != SessionState.ACTIVE:
328-
logger.debug(
329-
"Session is closed, no need to send heartbeat, state : "
330-
"%r close_session_after_this: %r",
331-
{state},
332-
{get_closing_grace_period()},
333-
)
334-
# session is closing / closed, no need to send heartbeat anymore
335-
return
336-
try:
337-
await send_message(
338-
stream_id="heartbeat",
339-
# TODO: make this a message class
340-
# https://github.com/replit/river/blob/741b1ea6d7600937ad53564e9cf8cd27a92ec36a/transport/message.ts#L42
341-
payload={
342-
"ack": 0,
343-
},
344-
control_flags=ACK_BIT,
345-
procedure_name=None,
346-
service_name=None,
347-
span=None,
348-
)
349-
350-
if increment_and_get_heartbeat_misses() > heartbeats_until_dead:
351-
if get_closing_grace_period() is not None:
352-
# already in grace period, no need to set again
353-
continue
354-
logger.info(
355-
"%r closing websocket because of heartbeat misses",
356-
session_id,
357-
)
358-
await close_websocket()
359-
continue
360-
except FailedSendingMessageException:
361-
# this is expected during websocket closed period
362-
continue
363-
364288
async def _send_buffered_messages(
365289
self, websocket: websockets.WebSocketCommonProtocol
366290
) -> None:

0 commit comments

Comments
 (0)