Skip to content

Commit ead7ce3

Browse files
Implement RPC timeouts (#123)
Why
===

Requiring an explicit timeout to be set on all RPC calls will prevent runaway execution.

What changed
============

Add a required `timeout: timedelta` parameter to all RPC methods.

Test plan
=========

TBD. https://github.com/replit/river-babel would be nice for this, but it's still undergoing some upgrade pains.
1 parent 798e59e commit ead7ce3

File tree

19 files changed

+409
-17
lines changed

19 files changed

+409
-17
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ dev-dependencies = [
5252

5353
[tool.ruff]
5454
lint.select = ["F", "E", "W", "I001"]
55+
exclude = ["*/generated/*"]
5556

5657
# Should be kept in sync with mypy.ini in the project root.
5758
# The VSCode mypy extension can only read /mypy.ini.

replit_river/client.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
33
from contextlib import contextmanager
4+
from datetime import timedelta
45
from typing import Any, Generator, Generic, Literal, Optional, Union
56

67
from opentelemetry import trace
@@ -60,6 +61,7 @@ async def send_rpc(
6061
request_serializer: Callable[[RequestType], Any],
6162
response_deserializer: Callable[[Any], ResponseType],
6263
error_deserializer: Callable[[Any], ErrorType],
64+
timeout: timedelta,
6365
) -> ResponseType:
6466
with _trace_procedure("rpc", service_name, procedure_name) as span:
6567
session = await self._transport.get_or_create_session()
@@ -71,6 +73,7 @@ async def send_rpc(
7173
response_deserializer,
7274
error_deserializer,
7375
span,
76+
timeout,
7477
)
7578

7679
async def send_upload(

replit_river/client_session.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import asyncio
12
import logging
23
from collections.abc import AsyncIterable, AsyncIterator
4+
from datetime import timedelta
35
from typing import Any, Callable, Optional, Union
46

57
import nanoid # type: ignore
@@ -8,6 +10,7 @@
810
from opentelemetry.trace import Span
911

1012
from replit_river.error_schema import (
13+
ERROR_CODE_CANCEL,
1114
ERROR_CODE_STREAM_CLOSED,
1215
RiverException,
1316
RiverServiceException,
@@ -39,6 +42,7 @@ async def send_rpc(
3942
response_deserializer: Callable[[Any], ResponseType],
4043
error_deserializer: Callable[[Any], ErrorType],
4144
span: Span,
45+
timeout: timedelta,
4246
) -> ResponseType:
4347
"""Sends a single RPC request to the server.
4448
@@ -58,7 +62,19 @@ async def send_rpc(
5862
# Handle potential errors during communication
5963
try:
6064
try:
61-
response = await output.get()
65+
async with asyncio.timeout(int(timeout.total_seconds())):
66+
response = await output.get()
67+
except asyncio.TimeoutError as e:
68+
# TODO(dstewart) After protocol v2, change this to STREAM_CANCEL_BIT
69+
await self.send_message(
70+
stream_id=stream_id,
71+
control_flags=STREAM_CLOSED_BIT,
72+
payload={"type": "CLOSE"},
73+
service_name=service_name,
74+
procedure_name=procedure_name,
75+
span=span,
76+
)
77+
raise RiverException(ERROR_CODE_CANCEL, str(e)) from e
6278
except ChannelClosed as e:
6379
raise RiverServiceException(
6480
ERROR_CODE_STREAM_CLOSED,

replit_river/codegen/client.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
# Code generated by river.codegen. DO NOT EDIT.
5757
from collections.abc import AsyncIterable, AsyncIterator
5858
from typing import Any
59+
import datetime
5960
6061
from pydantic import TypeAdapter
6162
@@ -857,6 +858,7 @@ def __init__(self, client: river.Client[Any]):
857858
async def {name}(
858859
self,
859860
input: {render_type_expr(input_type)},
861+
timeout: datetime.timedelta,
860862
) -> {render_type_expr(output_type)}:
861863
return await self.client.send_rpc(
862864
{repr(schema_name)},
@@ -865,6 +867,7 @@ async def {name}(
865867
{reindent(" ", render_input_method)},
866868
{reindent(" ", parse_output_method)},
867869
{reindent(" ", parse_error_method)},
870+
timeout,
868871
)
869872
""",
870873
)

replit_river/rpc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
]
5252
ACK_BIT = 0x0001
5353
STREAM_OPEN_BIT = 0x0002
54-
STREAM_CLOSED_BIT = 0x0004
54+
STREAM_CLOSED_BIT = 0x0004 # Synonymous with the cancel bit in v2
5555

5656
# these codes are retriable
5757
# if the server sends a response with one of these codes,

replit_river/session.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -203,9 +203,14 @@ async def _handle_messages_from_ws(
203203
)
204204
await self._add_msg_to_stream(msg, stream)
205205
else:
206-
stream = await self._open_stream_and_call_handler(
207-
msg, stream, tg
208-
)
206+
# TODO(dstewart) This looks like it opens a new call to handler
207+
# on every ws message, instead of demuxing and
208+
# routing.
209+
_stream = await self._open_stream_and_call_handler(msg, tg)
210+
if not stream:
211+
async with self._stream_lock:
212+
self._streams[msg.streamId] = _stream
213+
stream = _stream
209214

210215
if msg.controlFlags & STREAM_CLOSED_BIT != 0:
211216
if stream:
@@ -457,7 +462,6 @@ async def close_websocket(
457462
async def _open_stream_and_call_handler(
458463
self,
459464
msg: TransportMessage,
460-
stream: Optional[Channel],
461465
tg: Optional[asyncio.TaskGroup],
462466
) -> Channel:
463467
if not self._is_server:
@@ -496,9 +500,6 @@ async def _open_stream_and_call_handler(
496500
await input_stream.put(msg.payload)
497501
except (RuntimeError, ChannelClosed) as e:
498502
raise InvalidMessageException(e) from e
499-
if not stream:
500-
async with self._stream_lock:
501-
self._streams[msg.streamId] = input_stream
502503
# Start the handler.
503504
self._task_manager.create_task(
504505
handler_func(msg.from_, input_stream, output_stream), tg

replit_river/task_manager.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
import logging
3-
from typing import Any, Optional, Set
3+
from typing import Coroutine, Optional, Set
44

55
from replit_river.error_schema import ERROR_CODE_STREAM_CLOSED, RiverException
66

@@ -11,7 +11,7 @@ class BackgroundTaskManager:
1111
"""Manages background tasks and logs exceptions."""
1212

1313
def __init__(self) -> None:
14-
self.background_tasks: Set[asyncio.Task] = set()
14+
self.background_tasks: Set[asyncio.Task[None]] = set()
1515

1616
async def cancel_all_tasks(self) -> None:
1717
"""Asynchronously cancels all tasks managed by this instance."""
@@ -21,8 +21,8 @@ async def cancel_all_tasks(self) -> None:
2121

2222
@staticmethod
2323
async def cancel_task(
24-
task_to_remove: asyncio.Task[Any],
25-
background_tasks: Set[asyncio.Task],
24+
task_to_remove: asyncio.Task[None],
25+
background_tasks: Set[asyncio.Task[None]],
2626
) -> None:
2727
"""Cancels a given task and ensures it is removed from the set of managed tasks.
2828
@@ -50,8 +50,8 @@ async def cancel_task(
5050

5151
def _task_done_callback(
5252
self,
53-
task_to_remove: asyncio.Task[Any],
54-
background_tasks: Set[asyncio.Task],
53+
task_to_remove: asyncio.Task[None],
54+
background_tasks: Set[asyncio.Task[None]],
5555
) -> None:
5656
"""Callback to be executed when a task is done. It removes the task from the set
5757
and logs any exceptions.
@@ -83,8 +83,8 @@ def _task_done_callback(
8383
)
8484

8585
def create_task(
86-
self, fn: Any, tg: Optional[asyncio.TaskGroup] = None
87-
) -> asyncio.Task:
86+
self, fn: Coroutine[None, None, None], tg: Optional[asyncio.TaskGroup] = None
87+
) -> asyncio.Task[None]:
8888
"""Creates a task from a callable and adds it to the background tasks set.
8989
9090
Args:
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Code generated by river.codegen. DO NOT EDIT.
2+
from pydantic import BaseModel
3+
from typing import Literal
4+
5+
import replit_river as river
6+
7+
8+
from .test_service import Test_ServiceService
9+
10+
11+
class RpcClient:
12+
def __init__(self, client: river.Client[Literal[None]]):
13+
self.test_service = Test_ServiceService(client)
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Code generated by river.codegen. DO NOT EDIT.
2+
from collections.abc import AsyncIterable, AsyncIterator
3+
from typing import Any
4+
import datetime
5+
6+
from pydantic import TypeAdapter
7+
8+
from replit_river.error_schema import RiverError
9+
import replit_river as river
10+
11+
12+
from .rpc_method import encode_Rpc_MethodInput, Rpc_MethodInput, Rpc_MethodOutput
13+
14+
15+
class Test_ServiceService:
16+
def __init__(self, client: river.Client[Any]):
17+
self.client = client
18+
19+
async def rpc_method(
20+
self,
21+
input: Rpc_MethodInput,
22+
timeout: datetime.timedelta,
23+
) -> Rpc_MethodOutput:
24+
return await self.client.send_rpc(
25+
"test_service",
26+
"rpc_method",
27+
input,
28+
encode_Rpc_MethodInput,
29+
lambda x: TypeAdapter(Rpc_MethodOutput).validate_python(
30+
x # type: ignore[arg-type]
31+
),
32+
lambda x: TypeAdapter(RiverError).validate_python(
33+
x # type: ignore[arg-type]
34+
),
35+
timeout,
36+
)
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# ruff: noqa
2+
# Code generated by river.codegen. DO NOT EDIT.
3+
from collections.abc import AsyncIterable, AsyncIterator
4+
import datetime
5+
from typing import (
6+
Any,
7+
Callable,
8+
Dict,
9+
List,
10+
Literal,
11+
Optional,
12+
Mapping,
13+
Union,
14+
Tuple,
15+
TypedDict,
16+
)
17+
18+
from pydantic import BaseModel, Field, TypeAdapter
19+
from replit_river.error_schema import RiverError
20+
21+
import replit_river as river
22+
23+
24+
encode_Rpc_MethodInput: Callable[["Rpc_MethodInput"], Any] = lambda x: {
25+
k: v
26+
for (k, v) in (
27+
{
28+
"data": x.get("data"),
29+
}
30+
).items()
31+
if v is not None
32+
}
33+
34+
35+
class Rpc_MethodInput(TypedDict):
36+
data: str
37+
38+
39+
class Rpc_MethodOutput(BaseModel):
40+
data: str

0 commit comments

Comments
 (0)