diff --git a/python/ray/serve/_private/deployment_state.py b/python/ray/serve/_private/deployment_state.py index 37c7a5f90d4c..1671b712d86e 100644 --- a/python/ray/serve/_private/deployment_state.py +++ b/python/ray/serve/_private/deployment_state.py @@ -1515,6 +1515,7 @@ class ReplicaStateContainer: def __init__(self): self._replicas: Dict[ReplicaState, List[DeploymentReplica]] = defaultdict(list) + self._replica_id_index: Dict[ReplicaID, DeploymentReplica] = {} def add(self, state: ReplicaState, replica: DeploymentReplica): """Add the provided replica under the provided state. @@ -1526,6 +1527,7 @@ def add(self, state: ReplicaState, replica: DeploymentReplica): assert isinstance(state, ReplicaState), f"Type: {type(state)}" replica.update_state(state) self._replicas[state].append(replica) + self._replica_id_index[replica.replica_id] = replica def get( self, states: Optional[List[ReplicaState]] = None @@ -1540,12 +1542,23 @@ def get( are considered. """ if states is None: - states = ALL_REPLICA_STATES + return list(self._replica_id_index.values()) assert isinstance(states, list) return sum((self._replicas[state] for state in states), []) + def get_by_id(self, replica_id: ReplicaID) -> Optional[DeploymentReplica]: + """Get a replica by its ID in O(1) time. + + Args: + replica_id: the ID of the replica to look up. + + Returns: + The DeploymentReplica if found, else None. + """ + return self._replica_id_index.get(replica_id) + def pop( self, exclude_version: Optional[DeploymentVersion] = None, @@ -1587,6 +1600,9 @@ def pop( self._replicas[state] = remaining replicas.extend(popped) + for replica in replicas: + self._replica_id_index.pop(replica.replica_id, None) + return replicas def count( @@ -3496,17 +3512,16 @@ def record_request_routing_info(self, info: RequestRoutingInfo) -> None: info: RequestRoutingInfo including deployment name, replica tag, multiplex model ids, and routing stats. """ - # Find the replica - for replica in self._replicas.get(): - if replica.replica_id == info.replica_id: - if info.multiplexed_model_ids is not None: - replica.record_multiplexed_model_ids(info.multiplexed_model_ids) - if info.routing_stats is not None: - replica.record_routing_stats(info.routing_stats) - self._request_routing_info_updated = True - return - - logger.warning(f"{info.replica_id} not found.") + # O(1) lookup via replica_id index. + replica = self._replicas.get_by_id(info.replica_id) + if replica is not None: + if info.multiplexed_model_ids is not None: + replica.record_multiplexed_model_ids(info.multiplexed_model_ids) + if info.routing_stats is not None: + replica.record_routing_stats(info.routing_stats) + self._request_routing_info_updated = True + else: + logger.warning(f"{info.replica_id} not found.") def _stop_one_running_replica_for_testing(self): running_replicas = self._replicas.pop(states=[ReplicaState.RUNNING]) diff --git a/python/ray/serve/tests/unit/test_deployment_state.py b/python/ray/serve/tests/unit/test_deployment_state.py index 063ff17e81cc..b359596b417d 100644 --- a/python/ray/serve/tests/unit/test_deployment_state.py +++ b/python/ray/serve/tests/unit/test_deployment_state.py @@ -442,6 +442,13 @@ class FakeDeploymentReplica: def __init__(self, version: DeploymentVersion): self._version = version + self._replica_id = ReplicaID( + get_random_string(), deployment_id=DeploymentID(name="fake") + ) + + @property + def replica_id(self): + return self._replica_id @property def version(self): @@ -532,6 +539,37 @@ def test_get(self): assert c.get([ReplicaState.STARTING]) == [r1, r2] assert c.get([ReplicaState.STOPPING]) == [r3] + def test_get_by_id(self): + c = ReplicaStateContainer() + r1, r2, r3 = replica(), replica(), replica() + + c.add(ReplicaState.STARTING, r1) + c.add(ReplicaState.RUNNING, r2) + c.add(ReplicaState.STOPPING, r3) + + # Found: each replica is retrievable by its ID regardless of state. + assert c.get_by_id(r1.replica_id) is r1 + assert c.get_by_id(r2.replica_id) is r2 + assert c.get_by_id(r3.replica_id) is r3 + + # Not found: a replica ID that was never added returns None. + unknown = replica() + assert c.get_by_id(unknown.replica_id) is None + + # After pop: popped replicas are no longer in the index. + popped = c.pop(states=[ReplicaState.RUNNING]) + assert popped == [r2] + assert c.get_by_id(r2.replica_id) is None + + # Remaining replicas are still found. + assert c.get_by_id(r1.replica_id) is r1 + assert c.get_by_id(r3.replica_id) is r3 + + # Pop everything and verify the index is fully cleared. + c.pop() + assert c.get_by_id(r1.replica_id) is None + assert c.get_by_id(r3.replica_id) is None + def test_pop_basic(self): c = ReplicaStateContainer() r1, r2, r3 = replica(), replica(), replica()