|
1 | 1 | import asyncio |
2 | | -import enum |
3 | 2 | import logging |
4 | | -from typing import Any, Awaitable, Callable, Coroutine, Protocol |
| 3 | +from typing import Any, Callable, Coroutine |
5 | 4 |
|
6 | 5 | import nanoid # type: ignore |
7 | 6 | import websockets |
|
10 | 9 | from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator |
11 | 10 | from websockets.exceptions import ConnectionClosed |
12 | 11 |
|
| 12 | +from replit_river.common_session import SessionState, setup_heartbeat |
13 | 13 | from replit_river.message_buffer import MessageBuffer, MessageBufferClosedError |
14 | 14 | from replit_river.messages import ( |
15 | 15 | FailedSendingMessageException, |
|
42 | 42 | trace_setter = TransportMessageTracingSetter() |
43 | 43 |
|
44 | 44 |
|
45 | | -class SendMessage(Protocol): |
46 | | - async def __call__( |
47 | | - self, |
48 | | - *, |
49 | | - stream_id: str, |
50 | | - payload: dict[Any, Any] | str, |
51 | | - control_flags: int, |
52 | | - service_name: str | None, |
53 | | - procedure_name: str | None, |
54 | | - span: Span | None, |
55 | | - ) -> None: ... |
56 | | - |
57 | | - |
58 | | -class SessionState(enum.Enum): |
59 | | - """The state a session can be in. |
60 | | -
|
61 | | - Can only transition from ACTIVE to CLOSING to CLOSED. |
62 | | - """ |
63 | | - |
64 | | - ACTIVE = 0 |
65 | | - CLOSING = 1 |
66 | | - CLOSED = 2 |
67 | | - |
68 | | - |
69 | 45 | class Session: |
70 | 46 | """A transport object that handles the websocket connection with a client.""" |
71 | 47 |
|
@@ -131,7 +107,7 @@ def increment_and_get_heartbeat_misses() -> int: |
131 | 107 | return self._heartbeat_misses |
132 | 108 |
|
133 | 109 | self._task_manager.create_task( |
134 | | - self._heartbeat( |
| 110 | + setup_heartbeat( |
135 | 111 | self.session_id, |
136 | 112 | self._transport_options.heartbeat_ms, |
137 | 113 | self._transport_options.heartbeats_until_dead, |
@@ -309,58 +285,6 @@ async def _check_to_close_session(self) -> None: |
309 | 285 | await self.close() |
310 | 286 | return |
311 | 287 |
|
312 | | - async def _heartbeat( |
313 | | - self, |
314 | | - session_id: str, |
315 | | - heartbeat_ms: float, |
316 | | - heartbeats_until_dead: int, |
317 | | - get_state: Callable[[], SessionState], |
318 | | - get_closing_grace_period: Callable[[], float | None], |
319 | | - close_websocket: Callable[[], Awaitable[None]], |
320 | | - send_message: SendMessage, |
321 | | - increment_and_get_heartbeat_misses: Callable[[], int], |
322 | | - ) -> None: |
323 | | - logger.debug("Start heartbeat") |
324 | | - while True: |
325 | | - await asyncio.sleep(heartbeat_ms / 1000) |
326 | | - state = get_state() |
327 | | - if state != SessionState.ACTIVE: |
328 | | - logger.debug( |
329 | | - "Session is closed, no need to send heartbeat, state : " |
330 | | - "%r close_session_after_this: %r", |
331 | | - {state}, |
332 | | - {get_closing_grace_period()}, |
333 | | - ) |
334 | | - # session is closing / closed, no need to send heartbeat anymore |
335 | | - return |
336 | | - try: |
337 | | - await send_message( |
338 | | - stream_id="heartbeat", |
339 | | - # TODO: make this a message class |
340 | | - # https://github.com/replit/river/blob/741b1ea6d7600937ad53564e9cf8cd27a92ec36a/transport/message.ts#L42 |
341 | | - payload={ |
342 | | - "ack": 0, |
343 | | - }, |
344 | | - control_flags=ACK_BIT, |
345 | | - procedure_name=None, |
346 | | - service_name=None, |
347 | | - span=None, |
348 | | - ) |
349 | | - |
350 | | - if increment_and_get_heartbeat_misses() > heartbeats_until_dead: |
351 | | - if get_closing_grace_period() is not None: |
352 | | - # already in grace period, no need to set again |
353 | | - continue |
354 | | - logger.info( |
355 | | - "%r closing websocket because of heartbeat misses", |
356 | | - session_id, |
357 | | - ) |
358 | | - await close_websocket() |
359 | | - continue |
360 | | - except FailedSendingMessageException: |
361 | | - # this is expected during websocket closed period |
362 | | - continue |
363 | | - |
364 | 288 | async def _send_buffered_messages( |
365 | 289 | self, websocket: websockets.WebSocketCommonProtocol |
366 | 290 | ) -> None: |
|
0 commit comments