Skip to content

Commit 5bd3ad1

Browse files
Remove is_server
1 parent 3296e34 commit 5bd3ad1

File tree

3 files changed

+25
-16
lines changed

3 files changed

+25
-16
lines changed

src/replit_river/client_session.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
from typing import Any, AsyncGenerator, Callable, Coroutine
66

77
import nanoid # type: ignore
8+
import websockets
89
from aiochannel import Channel
910
from aiochannel.errors import ChannelClosed
1011
from opentelemetry.trace import Span
11-
import websockets
1212
from websockets.exceptions import ConnectionClosed
1313

1414
from replit_river.error_schema import (
@@ -72,6 +72,15 @@ def __init__(
7272
retry_connection_callback=retry_connection_callback,
7373
)
7474

75+
async def do_close_websocket() -> None:
76+
await self.close_websocket(
77+
self._ws_wrapper,
78+
should_retry=True,
79+
)
80+
await self._begin_close_session_countdown()
81+
82+
self._setup_heartbeats_task(do_close_websocket)
83+
7584
async def start_serve_responses(self) -> None:
7685
self._task_manager.create_task(self.serve())
7786

src/replit_river/server_session.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def __init__(
5050
websocket: websockets.WebSocketCommonProtocol,
5151
transport_options: TransportOptions,
5252
handlers: dict[tuple[str, str], tuple[str, GenericRpcHandlerBuilder]],
53-
close_session_callback: Callable[["Session"], Coroutine[Any, Any, Any]],
53+
close_session_callback: Callable[[Session], Coroutine[Any, Any, Any]],
5454
retry_connection_callback: (
5555
Callable[
5656
[],
@@ -68,9 +68,17 @@ def __init__(
6868
close_session_callback=close_session_callback,
6969
retry_connection_callback=retry_connection_callback,
7070
)
71-
self._is_server = True
7271
self._handlers = handlers
7372

73+
async def do_close_websocket() -> None:
74+
await self.close_websocket(
75+
self._ws_wrapper,
76+
should_retry=False,
77+
)
78+
await self._begin_close_session_countdown()
79+
80+
self._setup_heartbeats_task(do_close_websocket)
81+
7482
async def start_serve_responses(self) -> None:
7583
self._task_manager.create_task(self.serve())
7684

src/replit_river/session.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
import logging
3-
from typing import Any, Callable, Coroutine
3+
from typing import Any, Awaitable, Callable, Coroutine
44

55
import nanoid # type: ignore
66
import websockets
@@ -37,7 +37,6 @@
3737

3838
class Session:
3939
"""Common functionality shared between client_session and server_session"""
40-
_is_server: bool
4140

4241
def __init__(
4342
self,
@@ -58,7 +57,6 @@ def __init__(
5857
self._transport_id = transport_id
5958
self._to_id = to_id
6059
self.session_id = session_id
61-
self._is_server = False
6260
self._transport_options = transport_options
6361

6462
# session state, only modified during closing
@@ -83,16 +81,10 @@ def __init__(
8381
self._buffer = MessageBuffer(self._transport_options.buffer_size)
8482
self._task_manager = BackgroundTaskManager()
8583

86-
self._setup_heartbeats_task()
87-
88-
def _setup_heartbeats_task(self) -> None:
89-
async def do_close_websocket() -> None:
90-
await self.close_websocket(
91-
self._ws_wrapper,
92-
should_retry=not self._is_server,
93-
)
94-
await self._begin_close_session_countdown()
95-
84+
def _setup_heartbeats_task(
85+
self,
86+
do_close_websocket: Callable[[], Awaitable[None]],
87+
) -> None:
9688
def increment_and_get_heartbeat_misses() -> int:
9789
self._heartbeat_misses += 1
9890
return self._heartbeat_misses

0 commit comments

Comments
 (0)