10 | 10 | from fastapi import BackgroundTasks, FastAPI, HTTPException, Request |
11 | 11 | from fastapi.responses import Response, StreamingResponse |
12 | 12 | from vllm.engine.arg_utils import AsyncEngineArgs |
13 | | -from vllm.engine.async_llm_engine import AsyncLLMEngine |
| 13 | +from vllm.engine.async_llm_engine import AsyncEngineDeadError, AsyncLLMEngine |
14 | 14 | from vllm.entrypoints.openai.protocol import CompletionRequest as OpenAICompletionRequest |
15 | 15 | from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor |
16 | 16 | from vllm.outputs import CompletionOutput |
@@ -75,7 +75,11 @@ async def generate(request: Request) -> Response: |
75 | 75 | sampling_params.logits_processors.append(guided_decode_logit_processor) |
76 | 76 |
77 | 77 | request_id = random_uuid() |
78 | | - results_generator = engine.generate(prompt, sampling_params, request_id) |
| 78 | + try: |
| 79 | + results_generator = engine.generate(prompt, sampling_params, request_id) |
| 80 | + except AsyncEngineDeadError as e: |
| 81 | + print(f"The vllm engine is dead, exiting the pod: {e}") |
| 82 | + exit(1) |
79 | 83 |
80 | 84 | # Streaming case |
81 | 85 | async def stream_results() -> AsyncGenerator[str, None]: |
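In recent vLLM versions `AsyncLLMEngine.generate` is an async generator, so an `AsyncEngineDeadError` can also surface while `results_generator` is consumed inside `stream_results`, not only at the call site guarded above. A minimal sketch of guarding that loop the same way (the JSON payload shape is illustrative and assumes the module's existing `json` import; it is not code taken from this file):

```python
async def stream_results() -> AsyncGenerator[str, None]:
    try:
        # Iterating the generator is where a dead engine typically raises.
        async for request_output in results_generator:
            yield json.dumps({"text": [o.text for o in request_output.outputs]}) + "\n"
    except AsyncEngineDeadError as e:
        # Mirror the handling above: log and exit so the pod gets restarted.
        print(f"The vllm engine is dead, exiting the pod: {e}")
        exit(1)
```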
@@ -192,6 +196,7 @@ def extract_logprobs(logprobs: Dict[int, Logprob]) -> Dict[int, float]: |
192 | 196 |
193 | 197 | engine_args = AsyncEngineArgs.from_cli_args(args) |
194 | 198 | engine = AsyncLLMEngine.from_engine_args(engine_args) |
| 199 | + engine.check_health() |
195 | 200 |
196 | 201 | signal.signal(signal.SIGUSR1, debug) |
197 | 202 |
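In current vLLM, `AsyncLLMEngine.check_health()` is a coroutine that raises if the engine is unhealthy, so it must be awaited to have any effect. A minimal sketch of exposing the same probe over HTTP so an orchestrator can restart the pod (the `/healthz` route is an assumption, not part of this change; `app` is the FastAPI instance implied by the imports above):

```python
@app.get("/healthz")
async def healthz() -> Response:
    try:
        # check_health() raises (e.g. AsyncEngineDeadError) if the engine has died.
        await engine.check_health()
        return Response(status_code=200)
    except Exception as e:
        return Response(content=f"Engine unhealthy: {e}", status_code=500)
```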