Skip to content

Commit cf27ae6

Browse files
Remove is_server
1 parent 4f68627 commit cf27ae6

File tree

3 files changed

+25
-16
lines changed

3 files changed

+25
-16
lines changed

src/replit_river/client_session.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
from typing import Any, AsyncGenerator, Callable, Coroutine
66

77
import nanoid # type: ignore
8+
import websockets
89
from aiochannel import Channel
910
from aiochannel.errors import ChannelClosed
1011
from opentelemetry.trace import Span
11-
import websockets
1212
from websockets.exceptions import ConnectionClosed
1313

1414
from replit_river.error_schema import (
@@ -72,6 +72,15 @@ def __init__(
7272
retry_connection_callback=retry_connection_callback,
7373
)
7474

75+
async def do_close_websocket() -> None:
76+
await self.close_websocket(
77+
self._ws_wrapper,
78+
should_retry=True,
79+
)
80+
await self._begin_close_session_countdown()
81+
82+
self._setup_heartbeats_task(do_close_websocket)
83+
7584
async def start_serve_responses(self) -> None:
7685
self._task_manager.create_task(self.serve())
7786

src/replit_river/server_session.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def __init__(
5050
websocket: websockets.WebSocketCommonProtocol,
5151
transport_options: TransportOptions,
5252
handlers: dict[tuple[str, str], tuple[str, GenericRpcHandlerBuilder]],
53-
close_session_callback: Callable[["Session"], Coroutine[Any, Any, Any]],
53+
close_session_callback: Callable[[Session], Coroutine[Any, Any, Any]],
5454
retry_connection_callback: (
5555
Callable[
5656
[],
@@ -68,9 +68,17 @@ def __init__(
6868
close_session_callback=close_session_callback,
6969
retry_connection_callback=retry_connection_callback,
7070
)
71-
self._is_server = True
7271
self._handlers = handlers
7372

73+
async def do_close_websocket() -> None:
74+
await self.close_websocket(
75+
self._ws_wrapper,
76+
should_retry=False,
77+
)
78+
await self._begin_close_session_countdown()
79+
80+
self._setup_heartbeats_task(do_close_websocket)
81+
7482
async def start_serve_responses(self) -> None:
7583
self._task_manager.create_task(self.serve())
7684

src/replit_river/session.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import enum
33
import logging
4-
from typing import Any, Callable, Coroutine, Protocol
4+
from typing import Any, Awaitable, Callable, Coroutine, Protocol
55

66
import nanoid # type: ignore
77
import websockets
@@ -62,7 +62,6 @@ class SessionState(enum.Enum):
6262

6363
class Session:
6464
"""Common functionality shared between client_session and server_session"""
65-
_is_server: bool
6665

6766
def __init__(
6867
self,
@@ -83,7 +82,6 @@ def __init__(
8382
self._transport_id = transport_id
8483
self._to_id = to_id
8584
self.session_id = session_id
86-
self._is_server = False
8785
self._transport_options = transport_options
8886

8987
# session state, only modified during closing
@@ -108,16 +106,10 @@ def __init__(
108106
self._buffer = MessageBuffer(self._transport_options.buffer_size)
109107
self._task_manager = BackgroundTaskManager()
110108

111-
self._setup_heartbeats_task()
112-
113-
def _setup_heartbeats_task(self) -> None:
114-
async def do_close_websocket() -> None:
115-
await self.close_websocket(
116-
self._ws_wrapper,
117-
should_retry=not self._is_server,
118-
)
119-
await self._begin_close_session_countdown()
120-
109+
def _setup_heartbeats_task(
110+
self,
111+
do_close_websocket: Callable[[], Awaitable[None]],
112+
) -> None:
121113
def increment_and_get_heartbeat_misses() -> int:
122114
self._heartbeat_misses += 1
123115
return self._heartbeat_misses

0 commit comments

Comments
 (0)