Skip to content

Commit 2a2cc58

Browse files
Upgrade to the new websocket impl
1 parent e51da74 commit 2a2cc58

File tree

6 files changed

+25
-37
lines changed

6 files changed

+25
-37
lines changed

src/replit_river/common_session.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from aiochannel import Channel, ChannelClosed
77
from opentelemetry.trace import Span
88
from websockets import WebSocketCommonProtocol
9+
from websockets.asyncio.client import ClientConnection
910

1011
from replit_river.messages import (
1112
FailedSendingMessageException,
@@ -128,7 +129,7 @@ async def check_to_close_session(
128129

129130

130131
async def buffered_message_sender(
131-
get_ws: Callable[[], WebSocketCommonProtocol | None],
132+
get_ws: Callable[[], WebSocketCommonProtocol | ClientConnection | None],
132133
websocket_closed_callback: Callable[[], Coroutine[Any, Any, None]],
133134
get_next_pending: Callable[[], TransportMessage | None],
134135
commit: Callable[[TransportMessage], None],

src/replit_river/messages.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from websockets import (
99
WebSocketCommonProtocol,
1010
)
11+
from websockets.asyncio.client import ClientConnection
1112

1213
from replit_river.rpc import (
1314
TransportMessage,
@@ -31,7 +32,7 @@ class FailedSendingMessageException(Exception):
3132

3233
async def send_transport_message(
3334
msg: TransportMessage,
34-
ws: WebSocketCommonProtocol,
35+
ws: WebSocketCommonProtocol | ClientConnection, # legacy | asyncio
3536
websocket_closed_callback: Callable[[], Coroutine[Any, Any, None]],
3637
) -> None:
3738
logger.debug("sending a message %r to ws %s", msg, ws)

src/replit_river/v2/client.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -58,27 +58,6 @@ def translate_unknown_error(
5858
return RiverUnknownError(code=ERROR_CODE_UNKNOWN, message="Unknown error")
5959

6060

61-
# Client[HandshakeSchema](
62-
# uri_and_metadata_factory=uri_and_metadata_factory,
63-
# client_id=self.client_id,
64-
# server_id="SERVER",
65-
# transport_options=TransportOptions(
66-
# session_disconnect_grace_ms=settings.RIVER_SESSION_DISCONNECT_GRACE_MS,
67-
# heartbeat_ms=settings.RIVER_HEARTBEAT_MS,
68-
# heartbeats_until_dead=settings.RIVER_HEARTBEATS_UNTIL_DEAD,
69-
# connection_retry_options=ConnectionRetryOptions(
70-
# base_interval_ms=settings.RIVER_CONNECTION_BASE_INTERVAL_MS,
71-
# max_jitter_ms=settings.RIVER_CONNECTION_MAX_JITTER_MS,
72-
# max_backoff_ms=settings.RIVER_CONNECTION_MAX_BACKOFF_MS,
73-
# attempt_budget_capacity=self.attempt_budget_capacity,
74-
# budget_restore_interval_ms=
75-
# settings.RIVER_CONNECTION_BUDGET_RESTORE_INTERVAL_MS,
76-
# max_retry=self.max_retry_count,
77-
# ),
78-
# ),
79-
# )
80-
81-
8261
class Client(Generic[HandshakeMetadataType]):
8362
def __init__(
8463
self,

src/replit_river/v2/client_session.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,17 @@
22
import logging
33
from collections.abc import AsyncIterable
44
from datetime import timedelta
5-
from typing import Any, AsyncGenerator, Callable, Literal
5+
from typing import Any, AsyncGenerator, Callable, Literal, cast
66

77
import nanoid
88
import websockets
99
from aiochannel import Channel
1010
from aiochannel.errors import ChannelClosed
1111
from opentelemetry.trace import Span
12+
from websockets.asyncio.client import ClientConnection
1213
from websockets.exceptions import ConnectionClosed
1314
from websockets.frames import CloseCode
15+
from websockets.legacy.protocol import WebSocketCommonProtocol
1416

1517
from replit_river.common_session import buffered_message_sender
1618
from replit_river.error_schema import (
@@ -58,7 +60,7 @@ def __init__(
5860
transport_id: str,
5961
to_id: str,
6062
session_id: str,
61-
websocket: websockets.WebSocketCommonProtocol,
63+
websocket: ClientConnection,
6264
transport_options: TransportOptions,
6365
close_session_callback: CloseSessionCallback,
6466
retry_connection_callback: RetryConnectionCallback | None = None,
@@ -101,7 +103,11 @@ def get_next_pending() -> TransportMessage | None:
101103

102104
self._task_manager.create_task(
103105
buffered_message_sender(
104-
get_ws=lambda: self._ws_unwrapped if self.is_websocket_open() else None,
106+
get_ws=lambda: (
107+
cast(WebSocketCommonProtocol | ClientConnection, self._ws_unwrapped)
108+
if self.is_websocket_open()
109+
else None
110+
),
105111
websocket_closed_callback=self._begin_close_session_countdown,
106112
get_next_pending=get_next_pending,
107113
commit=commit,

src/replit_river/v2/client_transport.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
import nanoid
77
import websockets
88
from pydantic import ValidationError
9-
from websockets import (
10-
WebSocketCommonProtocol,
11-
)
9+
import websockets.asyncio.client
10+
from websockets import WebSocketCommonProtocol
11+
from websockets.asyncio.client import ClientConnection
1212
from websockets.exceptions import ConnectionClosed
1313

1414
from replit_river.error_schema import (
@@ -129,7 +129,7 @@ async def _establish_new_connection(
129129
self,
130130
old_session: ClientSession | None = None,
131131
) -> tuple[
132-
WebSocketCommonProtocol,
132+
ClientConnection,
133133
ControlMessageHandshakeRequest[HandshakeMetadataType],
134134
ControlMessageHandshakeResponse,
135135
]:
@@ -159,7 +159,7 @@ async def _establish_new_connection(
159159

160160
try:
161161
uri_and_metadata = await self._uri_and_metadata_factory()
162-
ws = await websockets.connect(uri_and_metadata["uri"])
162+
ws = await websockets.asyncio.client.connect(uri_and_metadata["uri"])
163163
session_id: str
164164
if old_session:
165165
session_id = old_session.session_id
@@ -228,7 +228,7 @@ async def _send_handshake_request(
228228
self,
229229
session_id: str,
230230
handshake_metadata: HandshakeMetadataType | None,
231-
websocket: WebSocketCommonProtocol,
231+
websocket: ClientConnection,
232232
expected_session_state: ExpectedSessionState,
233233
) -> ControlMessageHandshakeRequest[HandshakeMetadataType]:
234234
handshake_request = ControlMessageHandshakeRequest[HandshakeMetadataType](
@@ -266,7 +266,7 @@ async def websocket_closed_callback() -> None:
266266
) from e
267267

268268
async def _get_handshake_response_msg(
269-
self, websocket: WebSocketCommonProtocol
269+
self, websocket: ClientConnection
270270
) -> TransportMessage:
271271
while True:
272272
try:
@@ -295,7 +295,7 @@ async def _establish_handshake(
295295
self,
296296
session_id: str,
297297
handshake_metadata: HandshakeMetadataType,
298-
websocket: WebSocketCommonProtocol,
298+
websocket: ClientConnection,
299299
old_session: ClientSession | None,
300300
) -> tuple[
301301
ControlMessageHandshakeRequest[HandshakeMetadataType],

src/replit_river/v2/session.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from aiochannel import Channel
99
from opentelemetry.trace import Span, use_span
1010
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
11+
from websockets.asyncio.client import ClientConnection
1112
from websockets.frames import CloseCode
1213

1314
from replit_river.common_session import (
@@ -47,7 +48,7 @@ class Session:
4748

4849
# ws state
4950
_ws_connected: bool
50-
_ws_unwrapped: websockets.WebSocketCommonProtocol | None
51+
_ws_unwrapped: ClientConnection | None
5152
_heartbeat_misses: int
5253
_retry_connection_callback: RetryConnectionCallback | None
5354

@@ -66,7 +67,7 @@ def __init__(
6667
transport_id: str,
6768
to_id: str,
6869
session_id: str,
69-
websocket: websockets.WebSocketCommonProtocol,
70+
websocket: ClientConnection,
7071
transport_options: TransportOptions,
7172
close_session_callback: CloseSessionCallback,
7273
retry_connection_callback: RetryConnectionCallback | None = None,
@@ -162,7 +163,7 @@ async def _begin_close_session_countdown(self) -> None:
162163
self._ws_connected = False
163164

164165
async def replace_with_new_websocket(
165-
self, new_ws: websockets.WebSocketCommonProtocol
166+
self, new_ws: ClientConnection
166167
) -> None:
167168
if self._ws_unwrapped and new_ws.id != self._ws_unwrapped.id:
168169
self._task_manager.create_task(

0 commit comments

Comments
 (0)