Skip to content

Commit da4b4d3

Browse files
JoshBClemonsTarunRavikumar
authored andcommitted
[MLI-5020] Get endpoints in parallel (#723)
1 parent 0bf531b commit da4b4d3

File tree

1 file changed

+21
-4
lines changed

1 file changed

+21
-4
lines changed

model-engine/model_engine_server/infra/services/live_llm_model_endpoint_service.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from typing import List, Optional
23

34
from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy
@@ -26,19 +27,35 @@ async def list_llm_model_endpoints(
2627
owner: Optional[str],
2728
name: Optional[str],
2829
order_by: Optional[ModelEndpointOrderBy],
30+
fetch_batch_size: int = 10,
2931
) -> List[ModelEndpoint]:
3032
# Will read from cache at first
3133
records = await self.model_endpoint_record_repository.list_llm_model_endpoint_records(
3234
owner=owner,
3335
name=name,
3436
order_by=order_by,
3537
)
38+
39+
# Get model endpoints in parallel
3640
endpoints: List[ModelEndpoint] = []
37-
for record in records:
38-
infra_state = await self.model_endpoint_service._get_model_endpoint_infra_state(
39-
record=record, use_cache=True
41+
for start_idx in range(0, len(records), fetch_batch_size):
42+
end_idx = min(start_idx + fetch_batch_size, len(records))
43+
record_slice = records[start_idx:end_idx]
44+
infra_states = await asyncio.gather(
45+
*[
46+
self.model_endpoint_service._get_model_endpoint_infra_state(
47+
record=record, use_cache=True
48+
)
49+
for record in record_slice
50+
]
51+
)
52+
endpoints.extend(
53+
[
54+
ModelEndpoint(record=record, infra_state=infra_state)
55+
for record, infra_state in zip(record_slice, infra_states)
56+
]
4057
)
41-
endpoints.append(ModelEndpoint(record=record, infra_state=infra_state))
58+
4259
return endpoints
4360

4461
async def get_llm_model_endpoint(self, model_endpoint_name: str) -> Optional[ModelEndpoint]:

0 commit comments

Comments
 (0)