|
| 1 | +import asyncio |
1 | 2 | from typing import List, Optional |
2 | 3 |
|
3 | 4 | from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy |
async def list_llm_model_endpoints(
    self,
    owner: Optional[str],
    name: Optional[str],
    order_by: Optional[ModelEndpointOrderBy],
    fetch_batch_size: int = 10,
) -> List[ModelEndpoint]:
    """List LLM model endpoints, resolving each endpoint's infra state concurrently.

    Records are read first (the repository will serve from cache when it can),
    then infra states are fetched with ``asyncio.gather`` in batches of
    ``fetch_batch_size`` so a large listing does not fan out into an unbounded
    number of simultaneous lookups.

    Args:
        owner: Optional owner filter passed through to the record repository.
        name: Optional endpoint-name filter passed through to the record repository.
        order_by: Optional ordering for the returned records.
        fetch_batch_size: Maximum number of infra-state lookups to run
            concurrently per batch. Defaults to 10.

    Returns:
        A list of ``ModelEndpoint`` objects, one per record, in record order.

    Raises:
        ValueError: If ``fetch_batch_size`` is less than 1 (0 would make
            ``range`` raise an opaque error; a negative value would silently
            return an empty list).
    """
    if fetch_batch_size < 1:
        raise ValueError(f"fetch_batch_size must be >= 1, got {fetch_batch_size}")

    # Will read from cache at first
    records = await self.model_endpoint_record_repository.list_llm_model_endpoint_records(
        owner=owner,
        name=name,
        order_by=order_by,
    )

    # Get model endpoints in parallel, one bounded batch at a time.
    endpoints: List[ModelEndpoint] = []
    for start_idx in range(0, len(records), fetch_batch_size):
        # Slicing past the end is safely truncated, so no min() clamp is needed.
        record_slice = records[start_idx : start_idx + fetch_batch_size]
        infra_states = await asyncio.gather(
            *(
                self.model_endpoint_service._get_model_endpoint_infra_state(
                    record=record, use_cache=True
                )
                for record in record_slice
            )
        )
        # gather preserves input order, so zip pairs each record with its state.
        endpoints.extend(
            ModelEndpoint(record=record, infra_state=infra_state)
            for record, infra_state in zip(record_slice, infra_states)
        )

    return endpoints
43 | 60 |
|
44 | 61 | async def get_llm_model_endpoint(self, model_endpoint_name: str) -> Optional[ModelEndpoint]: |
|
0 commit comments