Skip to content

Commit 3d4cbf0

Browse files
committed
Merge branch 'main' into th-stricter-error-types
2 parents 672492f + 2983732 commit 3d4cbf0

39 files changed

+2062
-779
lines changed

scripts/lint/src/lint/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def raise_err(code: int) -> None:
1111

1212
def main() -> None:
1313
fix = ["--fix"] if "--fix" in sys.argv else []
14-
raise_err(os.system(" ".join(["ruff", "check", "src"] + fix)))
15-
raise_err(os.system("ruff format src"))
14+
raise_err(os.system(" ".join(["ruff", "check", "src", "scripts", "tests"] + fix)))
15+
raise_err(os.system("ruff format src scripts tests"))
1616
raise_err(os.system("mypy src"))
1717
raise_err(os.system("pyright src"))

scripts/parity/check_parity.py

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Callable, Literal, TypedDict, TypeVar, Union
1+
from typing import Any, Callable, Literal, TypedDict, TypeVar
22

33
import pyd
44
import tyd
@@ -85,35 +85,37 @@ def testAgenttoollanguageserverOpendocumentInput() -> None:
8585
)
8686

8787

88-
kind_type = Union[
89-
Literal[1],
90-
Literal[2],
91-
Literal[3],
92-
Literal[4],
93-
Literal[5],
94-
Literal[6],
95-
Literal[7],
96-
Literal[8],
97-
Literal[9],
98-
Literal[10],
99-
Literal[11],
100-
Literal[12],
101-
Literal[13],
102-
Literal[14],
103-
Literal[15],
104-
Literal[16],
105-
Literal[17],
106-
Literal[18],
107-
Literal[19],
108-
Literal[20],
109-
Literal[21],
110-
Literal[22],
111-
Literal[23],
112-
Literal[24],
113-
Literal[25],
114-
Literal[26],
115-
None,
116-
]
88+
kind_type = (
89+
Literal[
90+
1,
91+
2,
92+
3,
93+
4,
94+
5,
95+
6,
96+
7,
97+
8,
98+
9,
99+
10,
100+
11,
101+
12,
102+
13,
103+
14,
104+
15,
105+
16,
106+
17,
107+
18,
108+
19,
109+
20,
110+
21,
111+
22,
112+
23,
113+
24,
114+
25,
115+
26,
116+
]
117+
| None
118+
)
117119

118120

119121
def testAgenttoollanguageserverGetcodesymbolInput() -> None:

scripts/parity/gen.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import random
22
import string
3-
from typing import Callable, Optional, TypeVar
3+
from typing import Callable, TypeVar
44

55
A = TypeVar("A")
66

@@ -37,7 +37,7 @@ def gen_choice(choices: list[A]) -> Callable[[], A]:
3737
return lambda: random.choice(choices)
3838

3939

40-
def gen_opt(gen_x: Callable[[], A]) -> Callable[[], Optional[A]]:
40+
def gen_opt(gen_x: Callable[[], A]) -> Callable[[], A | None]:
4141
return lambda: gen_x() if gen_bool() else None
4242

4343

src/replit_river/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from .client import Client
22
from .error_schema import RiverError
33
from .rpc import (
4-
GenericRpcHandler,
4+
GenericRpcHandlerBuilder,
55
GrpcContext,
66
rpc_method_handler,
77
stream_method_handler,
@@ -15,7 +15,7 @@
1515
"Server",
1616
"GrpcContext",
1717
"RiverError",
18-
"GenericRpcHandler",
18+
"GenericRpcHandlerBuilder",
1919
"rpc_method_handler",
2020
"subscription_method_handler",
2121
"upload_method_handler",

src/replit_river/client.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from contextlib import contextmanager
44
from dataclasses import dataclass
55
from datetime import timedelta
6-
from typing import Any, AsyncGenerator, Generator, Generic, Literal, Optional, Union
6+
from typing import Any, AsyncGenerator, Generator, Generic, Literal
77

88
from opentelemetry import trace
99
from opentelemetry.trace import Span, SpanKind, Status, StatusCode
@@ -119,9 +119,9 @@ async def send_upload(
119119
self,
120120
service_name: str,
121121
procedure_name: str,
122-
init: Optional[InitType],
122+
init: InitType | None,
123123
request: AsyncIterable[RequestType],
124-
init_serializer: Optional[Callable[[InitType], Any]],
124+
init_serializer: Callable[[InitType], Any] | None,
125125
request_serializer: Callable[[RequestType], Any],
126126
response_deserializer: Callable[[Any], ResponseType],
127127
error_deserializer: Callable[[Any], ErrorType],
@@ -148,7 +148,7 @@ async def send_subscription(
148148
request_serializer: Callable[[RequestType], Any],
149149
response_deserializer: Callable[[Any], ResponseType],
150150
error_deserializer: Callable[[Any], ErrorType],
151-
) -> AsyncGenerator[Union[ResponseType, RiverError], None]:
151+
) -> AsyncGenerator[ResponseType | RiverError, None]:
152152
with _trace_procedure(
153153
"subscription", service_name, procedure_name
154154
) as span_handle:
@@ -170,13 +170,13 @@ async def send_stream(
170170
self,
171171
service_name: str,
172172
procedure_name: str,
173-
init: Optional[InitType],
173+
init: InitType | None,
174174
request: AsyncIterable[RequestType],
175-
init_serializer: Optional[Callable[[InitType], Any]],
175+
init_serializer: Callable[[InitType], Any] | None,
176176
request_serializer: Callable[[RequestType], Any],
177177
response_deserializer: Callable[[Any], ResponseType],
178178
error_deserializer: Callable[[Any], ErrorType],
179-
) -> AsyncGenerator[Union[ResponseType, RiverError], None]:
179+
) -> AsyncGenerator[ResponseType | RiverError, None]:
180180
with _trace_procedure("stream", service_name, procedure_name) as span_handle:
181181
session = await self._transport.get_or_create_session()
182182
async for msg in session.send_stream(
@@ -204,8 +204,8 @@ class _SpanHandle:
204204

205205
def set_status(
206206
self,
207-
status: Union[Status, StatusCode],
208-
description: Optional[str] = None,
207+
status: Status | StatusCode,
208+
description: str | None = None,
209209
) -> None:
210210
if self.did_set_status:
211211
return

src/replit_river/client_session.py

Lines changed: 144 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@
22
import logging
33
from collections.abc import AsyncIterable
44
from datetime import timedelta
5-
from typing import Any, AsyncGenerator, Callable, Optional, Union
5+
from typing import Any, AsyncGenerator, Callable, Coroutine
66

77
import nanoid # type: ignore
8+
import websockets
89
from aiochannel import Channel
910
from aiochannel.errors import ChannelClosed
1011
from opentelemetry.trace import Span
12+
from websockets.exceptions import ConnectionClosed
1113

14+
from replit_river.common_session import add_msg_to_stream
1215
from replit_river.error_schema import (
1316
ERROR_CODE_CANCEL,
1417
ERROR_CODE_STREAM_CLOSED,
@@ -17,10 +20,20 @@
1720
StreamClosedRiverServiceException,
1821
exception_from_message,
1922
)
23+
from replit_river.messages import (
24+
FailedSendingMessageException,
25+
parse_transport_msg,
26+
)
27+
from replit_river.seq_manager import (
28+
IgnoreMessageException,
29+
InvalidMessageException,
30+
OutOfOrderMessageException,
31+
)
2032
from replit_river.session import Session
21-
from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE
33+
from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE, TransportOptions
2234

2335
from .rpc import (
36+
ACK_BIT,
2437
STREAM_CLOSED_BIT,
2538
STREAM_OPEN_BIT,
2639
ErrorType,
@@ -33,6 +46,129 @@
3346

3447

3548
class ClientSession(Session):
49+
def __init__(
50+
self,
51+
transport_id: str,
52+
to_id: str,
53+
session_id: str,
54+
websocket: websockets.WebSocketCommonProtocol,
55+
transport_options: TransportOptions,
56+
close_session_callback: Callable[[Session], Coroutine[Any, Any, Any]],
57+
retry_connection_callback: (
58+
Callable[
59+
[],
60+
Coroutine[Any, Any, Any],
61+
]
62+
| None
63+
) = None,
64+
) -> None:
65+
super().__init__(
66+
transport_id=transport_id,
67+
to_id=to_id,
68+
session_id=session_id,
69+
websocket=websocket,
70+
transport_options=transport_options,
71+
close_session_callback=close_session_callback,
72+
retry_connection_callback=retry_connection_callback,
73+
)
74+
75+
async def do_close_websocket() -> None:
76+
await self.close_websocket(
77+
self._ws_wrapper,
78+
should_retry=True,
79+
)
80+
await self._begin_close_session_countdown()
81+
82+
self._setup_heartbeats_task(do_close_websocket)
83+
84+
async def start_serve_responses(self) -> None:
85+
self._task_manager.create_task(self.serve())
86+
87+
async def serve(self) -> None:
88+
"""Serve messages from the websocket."""
89+
self._reset_session_close_countdown()
90+
try:
91+
try:
92+
await self._handle_messages_from_ws()
93+
except ConnectionClosed:
94+
if self._retry_connection_callback:
95+
self._task_manager.create_task(self._retry_connection_callback())
96+
97+
await self._begin_close_session_countdown()
98+
logger.debug("ConnectionClosed while serving", exc_info=True)
99+
except FailedSendingMessageException:
100+
# Expected error if the connection is closed.
101+
logger.debug(
102+
"FailedSendingMessageException while serving", exc_info=True
103+
)
104+
except Exception:
105+
logger.exception("caught exception at message iterator")
106+
except ExceptionGroup as eg:
107+
_, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed))
108+
if unhandled:
109+
raise ExceptionGroup(
110+
"Unhandled exceptions on River server", unhandled.exceptions
111+
)
112+
113+
async def _handle_messages_from_ws(self) -> None:
114+
logger.debug(
115+
"%s start handling messages from ws %s",
116+
"client",
117+
self._ws_wrapper.id,
118+
)
119+
try:
120+
ws_wrapper = self._ws_wrapper
121+
async for message in ws_wrapper.ws:
122+
try:
123+
if not await ws_wrapper.is_open():
124+
# We should not process messages if the websocket is closed.
125+
break
126+
msg = parse_transport_msg(message, self._transport_options)
127+
128+
logger.debug(f"{self._transport_id} got a message %r", msg)
129+
130+
# Update bookkeeping
131+
await self._seq_manager.check_seq_and_update(msg)
132+
await self._buffer.remove_old_messages(
133+
self._seq_manager.receiver_ack,
134+
)
135+
self._reset_session_close_countdown()
136+
137+
if msg.controlFlags & ACK_BIT != 0:
138+
continue
139+
async with self._stream_lock:
140+
stream = self._streams.get(msg.streamId, None)
141+
if msg.controlFlags & STREAM_OPEN_BIT == 0:
142+
if not stream:
143+
logger.warning("no stream for %s", msg.streamId)
144+
raise IgnoreMessageException(
145+
"no stream for message, ignoring"
146+
)
147+
await add_msg_to_stream(msg, stream)
148+
else:
149+
raise InvalidMessageException(
150+
"Client should not receive stream open bit"
151+
)
152+
153+
if msg.controlFlags & STREAM_CLOSED_BIT != 0:
154+
if stream:
155+
stream.close()
156+
async with self._stream_lock:
157+
del self._streams[msg.streamId]
158+
except IgnoreMessageException:
159+
logger.debug("Ignoring transport message", exc_info=True)
160+
continue
161+
except OutOfOrderMessageException:
162+
logger.exception("Out of order message, closing connection")
163+
await ws_wrapper.close()
164+
return
165+
except InvalidMessageException:
166+
logger.exception("Got invalid transport message, closing session")
167+
await self.close()
168+
return
169+
except ConnectionClosed as e:
170+
raise e
171+
36172
async def send_rpc(
37173
self,
38174
service_name: str,
@@ -102,9 +238,9 @@ async def send_upload(
102238
self,
103239
service_name: str,
104240
procedure_name: str,
105-
init: Optional[InitType],
241+
init: InitType | None,
106242
request: AsyncIterable[RequestType],
107-
init_serializer: Optional[Callable[[InitType], Any]],
243+
init_serializer: Callable[[InitType], Any] | None,
108244
request_serializer: Callable[[RequestType], Any],
109245
response_deserializer: Callable[[Any], ResponseType],
110246
error_deserializer: Callable[[Any], ErrorType],
@@ -194,7 +330,7 @@ async def send_subscription(
194330
response_deserializer: Callable[[Any], ResponseType],
195331
error_deserializer: Callable[[Any], ErrorType],
196332
span: Span,
197-
) -> AsyncGenerator[Union[ResponseType, ErrorType], None]:
333+
) -> AsyncGenerator[ResponseType | ErrorType, None]:
198334
"""Sends a subscription request to the server.
199335
200336
Expects the input and output be messages that will be msgpacked.
@@ -241,14 +377,14 @@ async def send_stream(
241377
self,
242378
service_name: str,
243379
procedure_name: str,
244-
init: Optional[InitType],
380+
init: InitType | None,
245381
request: AsyncIterable[RequestType],
246-
init_serializer: Optional[Callable[[InitType], Any]],
382+
init_serializer: Callable[[InitType], Any] | None,
247383
request_serializer: Callable[[RequestType], Any],
248384
response_deserializer: Callable[[Any], ResponseType],
249385
error_deserializer: Callable[[Any], ErrorType],
250386
span: Span,
251-
) -> AsyncGenerator[Union[ResponseType, ErrorType], None]:
387+
) -> AsyncGenerator[ResponseType | ErrorType, None]:
252388
"""Sends a subscription request to the server.
253389
254390
Expects the input and output be messages that will be msgpacked.

0 commit comments

Comments
 (0)