Skip to content

Commit b0f4553

Browse files
Moving setup_heartbeat and check_to_close_session over to v1 session.py
1 parent df0e0c4 commit b0f4553

File tree

2 files changed

+85
-84
lines changed

2 files changed

+85
-84
lines changed

src/replit_river/common_session.py

Lines changed: 2 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import enum
33
import logging
4-
from typing import Any, Awaitable, Callable, Coroutine, Protocol
4+
from typing import Any, Callable, Coroutine, Protocol
55

66
from opentelemetry.trace import Span
77
from websockets import WebSocketCommonProtocol
@@ -12,7 +12,7 @@
1212
WebsocketClosedException,
1313
send_transport_message,
1414
)
15-
from replit_river.rpc import ACK_BIT, TransportMessage
15+
from replit_river.rpc import TransportMessage
1616

1717
logger = logging.getLogger(__name__)
1818

@@ -52,86 +52,6 @@ class SessionState(enum.Enum):
5252
TerminalStates = set([SessionState.CLOSING, SessionState.CLOSED])
5353

5454

55-
async def setup_heartbeat(
56-
session_id: str,
57-
heartbeat_ms: float,
58-
heartbeats_until_dead: int,
59-
get_state: Callable[[], SessionState],
60-
get_closing_grace_period: Callable[[], float | None],
61-
close_websocket: Callable[[], Awaitable[None]],
62-
send_message: SendMessage,
63-
increment_and_get_heartbeat_misses: Callable[[], int],
64-
) -> None:
65-
while True:
66-
await asyncio.sleep(heartbeat_ms / 1000)
67-
state = get_state()
68-
if state == SessionState.CONNECTING:
69-
logger.debug("Websocket is not connected, not sending heartbeat")
70-
continue
71-
if state in TerminalStates:
72-
logger.debug(
73-
"Session is closed, no need to send heartbeat, state : "
74-
"%r close_session_after_this: %r",
75-
{state},
76-
{get_closing_grace_period()},
77-
)
78-
# session is closing / closed, no need to send heartbeat anymore
79-
return
80-
try:
81-
await send_message(
82-
stream_id="heartbeat",
83-
# TODO: make this a message class
84-
# https://github.com/replit/river/blob/741b1ea6d7600937ad53564e9cf8cd27a92ec36a/transport/message.ts#L42
85-
payload={
86-
"ack": 0,
87-
},
88-
control_flags=ACK_BIT,
89-
procedure_name=None,
90-
service_name=None,
91-
span=None,
92-
)
93-
94-
if increment_and_get_heartbeat_misses() > heartbeats_until_dead:
95-
if get_closing_grace_period() is not None:
96-
# already in grace period, no need to set again
97-
continue
98-
logger.info(
99-
"%r closing websocket because of heartbeat misses",
100-
session_id,
101-
)
102-
await close_websocket()
103-
continue
104-
except FailedSendingMessageException:
105-
# this is expected during websocket closed period
106-
continue
107-
108-
109-
async def check_to_close_session(
110-
transport_id: str,
111-
close_session_check_interval_ms: float,
112-
get_state: Callable[[], SessionState],
113-
get_current_time: Callable[[], Awaitable[float]],
114-
get_close_session_after_time_secs: Callable[[], float | None],
115-
do_close: Callable[[], Awaitable[None]],
116-
) -> None:
117-
while True:
118-
await asyncio.sleep(close_session_check_interval_ms / 1000)
119-
if get_state() in TerminalStates:
120-
# already closing
121-
return
122-
# calculate the value now before comparing it so that there are no
123-
# await points between the check and the comparison to avoid a TOCTOU
124-
# race.
125-
current_time = await get_current_time()
126-
close_session_after_time_secs = get_close_session_after_time_secs()
127-
if not close_session_after_time_secs:
128-
continue
129-
if current_time > close_session_after_time_secs:
130-
logger.info("Grace period ended for %s, closing session", transport_id)
131-
await do_close()
132-
return
133-
134-
13555
async def buffered_message_sender(
13656
connection_condition: asyncio.Condition,
13757
message_enqueued: asyncio.Semaphore,

src/replit_river/session.py

Lines changed: 83 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
1010

1111
from replit_river.common_session import (
12+
SendMessage,
1213
SessionState,
13-
check_to_close_session,
14-
setup_heartbeat,
14+
TerminalStates,
1515
)
1616
from replit_river.message_buffer import MessageBuffer, MessageBufferClosedError
1717
from replit_river.messages import (
@@ -27,6 +27,7 @@
2727
from replit_river.websocket_wrapper import WebsocketWrapper, WsState
2828

2929
from .rpc import (
30+
ACK_BIT,
3031
TransportMessage,
3132
TransportMessageTracingSetter,
3233
)
@@ -337,3 +338,83 @@ async def close(self) -> None:
337338
self._streams.clear()
338339

339340
self._state = SessionState.CLOSED
341+
342+
343+
async def check_to_close_session(
344+
transport_id: str,
345+
close_session_check_interval_ms: float,
346+
get_state: Callable[[], SessionState],
347+
get_current_time: Callable[[], Awaitable[float]],
348+
get_close_session_after_time_secs: Callable[[], float | None],
349+
do_close: Callable[[], Awaitable[None]],
350+
) -> None:
351+
while True:
352+
await asyncio.sleep(close_session_check_interval_ms / 1000)
353+
if get_state() in TerminalStates:
354+
# already closing
355+
return
356+
# calculate the value now before comparing it so that there are no
357+
# await points between the check and the comparison to avoid a TOCTOU
358+
# race.
359+
current_time = await get_current_time()
360+
close_session_after_time_secs = get_close_session_after_time_secs()
361+
if not close_session_after_time_secs:
362+
continue
363+
if current_time > close_session_after_time_secs:
364+
logger.info("Grace period ended for %s, closing session", transport_id)
365+
await do_close()
366+
return
367+
368+
369+
async def setup_heartbeat(
370+
session_id: str,
371+
heartbeat_ms: float,
372+
heartbeats_until_dead: int,
373+
get_state: Callable[[], SessionState],
374+
get_closing_grace_period: Callable[[], float | None],
375+
close_websocket: Callable[[], Awaitable[None]],
376+
send_message: SendMessage,
377+
increment_and_get_heartbeat_misses: Callable[[], int],
378+
) -> None:
379+
while True:
380+
await asyncio.sleep(heartbeat_ms / 1000)
381+
state = get_state()
382+
if state == SessionState.CONNECTING:
383+
logger.debug("Websocket is not connected, not sending heartbeat")
384+
continue
385+
if state in TerminalStates:
386+
logger.debug(
387+
"Session is closed, no need to send heartbeat, state : "
388+
"%r close_session_after_this: %r",
389+
{state},
390+
{get_closing_grace_period()},
391+
)
392+
# session is closing / closed, no need to send heartbeat anymore
393+
return
394+
try:
395+
await send_message(
396+
stream_id="heartbeat",
397+
# TODO: make this a message class
398+
# https://github.com/replit/river/blob/741b1ea6d7600937ad53564e9cf8cd27a92ec36a/transport/message.ts#L42
399+
payload={
400+
"ack": 0,
401+
},
402+
control_flags=ACK_BIT,
403+
procedure_name=None,
404+
service_name=None,
405+
span=None,
406+
)
407+
408+
if increment_and_get_heartbeat_misses() > heartbeats_until_dead:
409+
if get_closing_grace_period() is not None:
410+
# already in grace period, no need to set again
411+
continue
412+
logger.info(
413+
"%r closing websocket because of heartbeat misses",
414+
session_id,
415+
)
416+
await close_websocket()
417+
continue
418+
except FailedSendingMessageException:
419+
# this is expected during websocket closed period
420+
continue

0 commit comments

Comments
 (0)