Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 26 additions & 11 deletions python/ray/serve/_private/deployment_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1515,6 +1515,7 @@ class ReplicaStateContainer:

def __init__(self):
self._replicas: Dict[ReplicaState, List[DeploymentReplica]] = defaultdict(list)
self._replica_id_index: Dict[ReplicaID, DeploymentReplica] = {}

def add(self, state: ReplicaState, replica: DeploymentReplica):
"""Add the provided replica under the provided state.
Expand All @@ -1526,6 +1527,7 @@ def add(self, state: ReplicaState, replica: DeploymentReplica):
assert isinstance(state, ReplicaState), f"Type: {type(state)}"
replica.update_state(state)
self._replicas[state].append(replica)
self._replica_id_index[replica.replica_id] = replica

def get(
self, states: Optional[List[ReplicaState]] = None
Expand All @@ -1546,6 +1548,17 @@ def get(

return sum((self._replicas[state] for state in states), [])

def get_by_id(self, replica_id: ReplicaID) -> Optional[DeploymentReplica]:
"""Get a replica by its ID in O(1) time.

Args:
replica_id: the ID of the replica to look up.

Returns:
The DeploymentReplica if found, else None.
"""
return self._replica_id_index.get(replica_id)

def pop(
self,
exclude_version: Optional[DeploymentVersion] = None,
Expand Down Expand Up @@ -1587,6 +1600,9 @@ def pop(
self._replicas[state] = remaining
replicas.extend(popped)

for replica in replicas:
self._replica_id_index.pop(replica.replica_id, None)

return replicas

def count(
Expand Down Expand Up @@ -3496,17 +3512,16 @@ def record_request_routing_info(self, info: RequestRoutingInfo) -> None:
info: RequestRoutingInfo including deployment name, replica tag,
multiplex model ids, and routing stats.
"""
# Find the replica
for replica in self._replicas.get():
if replica.replica_id == info.replica_id:
if info.multiplexed_model_ids is not None:
replica.record_multiplexed_model_ids(info.multiplexed_model_ids)
if info.routing_stats is not None:
replica.record_routing_stats(info.routing_stats)
self._request_routing_info_updated = True
return

logger.warning(f"{info.replica_id} not found.")
# O(1) lookup via replica_id index.
replica = self._replicas.get_by_id(info.replica_id)
if replica is not None:
if info.multiplexed_model_ids is not None:
replica.record_multiplexed_model_ids(info.multiplexed_model_ids)
if info.routing_stats is not None:
replica.record_routing_stats(info.routing_stats)
self._request_routing_info_updated = True
else:
logger.warning(f"{info.replica_id} not found.")

def _stop_one_running_replica_for_testing(self):
running_replicas = self._replicas.pop(states=[ReplicaState.RUNNING])
Expand Down
7 changes: 7 additions & 0 deletions python/ray/serve/tests/unit/test_deployment_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,13 @@ class FakeDeploymentReplica:

def __init__(self, version: DeploymentVersion):
self._version = version
self._replica_id = ReplicaID(
get_random_string(), deployment_id=DeploymentID(name="fake")
)

@property
def replica_id(self):
return self._replica_id

@property
def version(self):
Expand Down