Skip to content

Commit ede2730

Browse files
[Serve] Add broadcast API for deployment handles that broadcasts the same RPC across all live replicas of a deployment (#61472)
Signed-off-by: bittoby <bittoby@users.noreply.github.com> Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com> Co-authored-by: bittoby <bittoby@users.noreply.github.com> Co-authored-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
1 parent fae2be4 commit ede2730

File tree

8 files changed

+795
-5
lines changed

8 files changed

+795
-5
lines changed

doc/source/serve/api/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ See the [model composition guide](serve-model-composition) for how to update cod
5656
serve.handle.DeploymentHandle
5757
serve.handle.DeploymentResponse
5858
serve.handle.DeploymentResponseGenerator
59+
serve.handle.DeploymentBroadcastResponse
5960
```
6061

6162
### Running Applications

python/ray/serve/_private/local_testing_mode.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import queue
66
import time
77
from functools import wraps
8-
from typing import Any, Callable, Coroutine, Dict, Optional, Tuple, Union
8+
from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Union
99

1010
import ray
1111
from ray import cloudpickle
@@ -341,6 +341,20 @@ def generator_result_callback(item: Any):
341341
)
342342
return noop_future
343343

344+
async def broadcast(
    self,
    request_meta: RequestMetadata,
    *request_args,
    **request_kwargs,
) -> List[ReplicaResult]:
    """Broadcast in local testing mode calls the single local replica.

    Local testing mode runs exactly one in-process replica, so a
    broadcast degenerates to a single ordinary request whose result is
    wrapped in a one-element list to match the broadcast contract.
    """
    future = self.assign_request(request_meta, *request_args, **request_kwargs)
    # Only one replica exists here; block for its result and return it
    # as the sole entry of the list.
    return [future.result()]
357+
344358
def shutdown(self):
345359
noop_future = concurrent.futures.Future()
346360
noop_future.set_result(None)

python/ray/serve/_private/router.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from collections import defaultdict
1010
from collections.abc import MutableMapping
1111
from contextlib import contextmanager
12+
from dataclasses import replace
1213
from functools import lru_cache, partial
1314
from typing import (
1415
Any,
@@ -512,6 +513,11 @@ async def shutdown(self):
512513

513514

514515
class Router(ABC):
516+
@property
def event_loop(self) -> Optional[AbstractEventLoop]:
    """The event loop the router runs on, or None (e.g. local testing)."""
    # Routers that run on an asyncio loop set `_asyncio_loop`; routers
    # without one (e.g. local testing mode) simply lack the attribute.
    try:
        return self._asyncio_loop
    except AttributeError:
        return None
520+
515521
@abstractmethod
516522
def running_replicas_populated(self) -> bool:
517523
pass
@@ -525,6 +531,15 @@ def assign_request(
525531
) -> concurrent.futures.Future[ReplicaResult]:
526532
pass
527533

534+
@abstractmethod
async def broadcast(
    self,
    request_meta: RequestMetadata,
    *request_args,
    **request_kwargs,
) -> List[ReplicaResult]:
    """Send the same request to every live replica of the deployment.

    Implementations return one ReplicaResult per replica the request
    was delivered to.
    """
542+
528543
@abstractmethod
529544
def shutdown(self) -> concurrent.futures.Future:
530545
pass
@@ -1157,6 +1172,115 @@ async def assign_request(
11571172
if exc:
11581173
set_span_exception(exc, escaped=True)
11591174

1175+
async def broadcast(
    self,
    request_meta: RequestMetadata,
    *request_args,
    **request_kwargs,
) -> List[ReplicaResult]:
    """Send a request to all running replicas in parallel.

    Bypasses the normal load-balancing path and sends the request
    directly to every replica. Waits for the request router to be
    initialized so the replica set is populated.

    Args:
        request_meta: Metadata for the request. Its tracing context is
            overwritten here, and a fresh internal_request_id is
            generated per replica.
        *request_args: Positional arguments forwarded to every replica.
        **request_kwargs: Keyword arguments forwarded to every replica.

    Returns:
        One ReplicaResult per replica the request was successfully sent
        to (dead/unavailable replicas are skipped).

    Raises:
        DeploymentUnavailableError: if the deployment is marked
            unavailable, the router has no replicas, or every send
            attempt failed.
    """
    # Propagate tracing context, matching assign_request behavior.
    if is_span_recording():
        propagate_context = create_propagated_context()
        request_meta.tracing_context = propagate_context
    else:
        request_meta.tracing_context = None

    if not self._deployment_available:
        raise DeploymentUnavailableError(self.deployment_id)

    # The replica set is only populated once the request router has been
    # initialized; re-check availability after the (possibly long) wait.
    await self._request_router_initialized.wait()

    if not self._deployment_available:
        raise DeploymentUnavailableError(self.deployment_id)

    # Snapshot the current replica set; membership may change while we
    # iterate, but each send below handles dead/unavailable replicas.
    replicas: List[RunningReplica] = list(
        self.request_router.curr_replicas.values()
    )
    if not replicas:
        raise DeploymentUnavailableError(self.deployment_id)

    # Resolve arguments (e.g. DeploymentResponse objects) before sending.
    pr = PendingRequest(
        args=list(request_args),
        kwargs=dict(request_kwargs),
        metadata=request_meta,
    )
    await self._resolve_request_arguments(pr)

    results: List[ReplicaResult] = []
    for replica in replicas:
        # Shallow-copy args/kwargs and give each replica its own request
        # metadata with a unique internal_request_id so per-request
        # bookkeeping (metrics, queue-length cache) stays distinct.
        replica_pr = PendingRequest(
            args=list(pr.args),
            kwargs=dict(pr.kwargs),
            metadata=replace(
                request_meta,
                internal_request_id=generate_request_id(),
            ),
        )
        # Arguments were resolved once above; mark the copy so the
        # send path does not resolve them again.
        replica_pr.resolved = True
        try:
            result = replica.try_send_request(replica_pr, with_rejection=False)
        except ActorDiedError:
            # Replica has died but controller hasn't notified the router yet.
            # Skip this replica and continue broadcasting to healthy replicas.
            self.request_router.on_replica_actor_died(replica.replica_id)
            logger.warning(
                f"{replica.replica_id} will not be considered for future "
                "requests because it has died."
            )
            continue
        except ActorUnavailableError:
            # Replica is temporarily unavailable. Invalidate the cache entry
            # and continue broadcasting to other replicas.
            self.request_router.on_replica_actor_unavailable(replica.replica_id)
            logger.warning(f"{replica.replica_id} is temporarily unavailable.")
            continue

        # Proactively update the queue length cache.
        self.request_router.on_send_request(replica.replica_id)

        # Track running requests and register callback for completion
        # handling, matching the pattern in _route_and_send_request_once.
        if RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE:
            self._metrics_manager.inc_num_running_requests_for_replica(
                replica.replica_id
            )
        # NOTE: add_done_callback fires from a C++ worker thread (for
        # actor ObjectRefs) or a gRPC callback thread.
        # _process_finished_request and decrement_queue_len_cache both
        # access shared router state that is not thread-safe, so we
        # schedule them on the router's event loop.
        # NOTE(review): this uses `self._event_loop` while the Router base
        # property reads `_asyncio_loop` — confirm both attributes exist
        # on this class. TODO: verify against the class __init__.
        callback = partial(
            self._process_finished_request,
            replica.replica_id,
            replica_pr.metadata.internal_request_id,
            replica.actor_id,
        )
        result.add_done_callback(
            lambda _, cb=callback: self._event_loop.call_soon_threadsafe(cb, _)
        )
        # Bind replica_id as a default arg: callbacks are late-binding,
        # and `replica` changes on every loop iteration.
        result.add_done_callback(
            lambda _, rid=replica.replica_id: (
                self._event_loop.call_soon_threadsafe(
                    self.request_router.decrement_queue_len_cache,
                    rid,
                )
            )
        )

        results.append(result)

    # All sends failed (every replica died or became unavailable mid-loop).
    if not results:
        raise DeploymentUnavailableError(self.deployment_id)

    return results
1283+
11601284
async def shutdown(self):
11611285
await self._metrics_manager.shutdown()
11621286

@@ -1301,6 +1425,16 @@ def create_task_and_setup():
13011425
self._asyncio_loop.call_soon_threadsafe(create_task_and_setup)
13021426
return concurrent_future
13031427

1428+
async def broadcast(
    self,
    request_meta: RequestMetadata,
    *request_args,
    **request_kwargs,
) -> List[ReplicaResult]:
    """Delegate a broadcast to the wrapped asyncio router."""
    # NOTE(review): unlike assign_request/shutdown in this class, which
    # schedule work onto self._asyncio_loop via call_soon_threadsafe /
    # run_coroutine_threadsafe, this awaits the asyncio router directly on
    # the caller's loop — confirm that is safe for cross-loop callers.
    coro = self._asyncio_router.broadcast(
        request_meta, *request_args, **request_kwargs
    )
    return await coro
1437+
13041438
def shutdown(self) -> concurrent.futures.Future:
13051439
return asyncio.run_coroutine_threadsafe(
13061440
self._asyncio_router.shutdown(), loop=self._asyncio_loop
@@ -1419,5 +1553,15 @@ def assign_request(
14191553
),
14201554
)
14211555

1556+
async def broadcast(
    self,
    request_meta: RequestMetadata,
    *request_args,
    **request_kwargs,
) -> List[ReplicaResult]:
    """Forward a broadcast request to the underlying asyncio router."""
    # This router runs on the same event loop as the asyncio router
    # (shutdown uses create_task on self._asyncio_loop), so awaiting
    # directly is the natural pass-through.
    return await self._asyncio_router.broadcast(
        request_meta,
        *request_args,
        **request_kwargs,
    )
1565+
14221566
def shutdown(self) -> asyncio.Future:
14231567
return self._asyncio_loop.create_task(self._asyncio_router.shutdown())

0 commit comments

Comments (0)