
Commit 3706618

[Frontend] Update OpenAI error response to upstream format (#22099)
Signed-off-by: Moritz Sanft <[email protected]>
1 parent cbc8457 commit 3706618
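
In short: vLLM's OpenAI-compatible frontend previously returned a flat, non-standard error body tagged with "object": "error"; after this commit the error fields are nested under an "error" key, matching the upstream OpenAI API format. A minimal sketch of the two shapes (field values are illustrative, not copied from a real response):

# Before this change: flat, vLLM-specific error body.
legacy_error = {
    "object": "error",
    "message": "The model does not support Transcriptions API",
    "type": "Bad Request",
    "param": None,
    "code": 400,
}

# After this change: the same fields nested under "error",
# matching the upstream OpenAI API error format.
upstream_error = {
    "error": {
        "message": "The model does not support Transcriptions API",
        "type": "Bad Request",
        "param": None,
        "code": 400,
    }
}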

File tree: 10 files changed, +73 -67 lines changed

tests/entrypoints/openai/test_classification.py (2 additions, 3 deletions)

@@ -121,8 +121,7 @@ def test_invalid_truncate_prompt_tokens_error(server: RemoteOpenAIServer,
     error = classification_response.json()
     assert classification_response.status_code == 400
-    assert error["object"] == "error"
-    assert "truncate_prompt_tokens" in error["message"]
+    assert "truncate_prompt_tokens" in error["error"]["message"]


 @pytest.mark.parametrize("model_name", [MODEL_NAME])
@@ -137,7 +136,7 @@ def test_empty_input_error(server: RemoteOpenAIServer, model_name: str):
     error = classification_response.json()
     assert classification_response.status_code == 400
-    assert error["object"] == "error"
+    assert "error" in error


 @pytest.mark.parametrize("model_name", [MODEL_NAME])

tests/entrypoints/openai/test_lora_resolvers.py (4 additions, 4 deletions)

@@ -160,8 +160,8 @@ async def test_serving_completion_resolver_not_found(mock_serving_setup,
     mock_engine.generate.assert_not_called()

     assert isinstance(response, ErrorResponse)
-    assert response.code == HTTPStatus.NOT_FOUND.value
-    assert non_existent_model in response.message
+    assert response.error.code == HTTPStatus.NOT_FOUND.value
+    assert non_existent_model in response.error.message


 @pytest.mark.asyncio
@@ -190,8 +190,8 @@ async def test_serving_completion_resolver_add_lora_fails(
     # Assert the correct error response
     assert isinstance(response, ErrorResponse)
-    assert response.code == HTTPStatus.BAD_REQUEST.value
-    assert invalid_model in response.message
+    assert response.error.code == HTTPStatus.BAD_REQUEST.value
+    assert invalid_model in response.error.message


 @pytest.mark.asyncio

tests/entrypoints/openai/test_serving_models.py (8 additions, 8 deletions)

@@ -66,8 +66,8 @@ async def test_load_lora_adapter_missing_fields():
     request = LoadLoRAAdapterRequest(lora_name="", lora_path="")
     response = await serving_models.load_lora_adapter(request)
     assert isinstance(response, ErrorResponse)
-    assert response.type == "InvalidUserInput"
-    assert response.code == HTTPStatus.BAD_REQUEST
+    assert response.error.type == "InvalidUserInput"
+    assert response.error.code == HTTPStatus.BAD_REQUEST


 @pytest.mark.asyncio
@@ -84,8 +84,8 @@ async def test_load_lora_adapter_duplicate():
                                     lora_path="/path/to/adapter1")
     response = await serving_models.load_lora_adapter(request)
     assert isinstance(response, ErrorResponse)
-    assert response.type == "InvalidUserInput"
-    assert response.code == HTTPStatus.BAD_REQUEST
+    assert response.error.type == "InvalidUserInput"
+    assert response.error.code == HTTPStatus.BAD_REQUEST
     assert len(serving_models.lora_requests) == 1


@@ -110,8 +110,8 @@ async def test_unload_lora_adapter_missing_fields():
     request = UnloadLoRAAdapterRequest(lora_name="", lora_int_id=None)
     response = await serving_models.unload_lora_adapter(request)
     assert isinstance(response, ErrorResponse)
-    assert response.type == "InvalidUserInput"
-    assert response.code == HTTPStatus.BAD_REQUEST
+    assert response.error.type == "InvalidUserInput"
+    assert response.error.code == HTTPStatus.BAD_REQUEST


 @pytest.mark.asyncio
@@ -120,5 +120,5 @@ async def test_unload_lora_adapter_not_found():
     request = UnloadLoRAAdapterRequest(lora_name="nonexistent_adapter")
     response = await serving_models.unload_lora_adapter(request)
     assert isinstance(response, ErrorResponse)
-    assert response.type == "NotFoundError"
-    assert response.code == HTTPStatus.NOT_FOUND
+    assert response.error.type == "NotFoundError"
+    assert response.error.code == HTTPStatus.NOT_FOUND

tests/entrypoints/openai/test_transcription_validation.py (11 additions, 6 deletions)

@@ -116,8 +116,10 @@ async def test_non_asr_model(winning_call):
                 file=winning_call,
                 language="en",
                 temperature=0.0)
-    assert res.code == 400 and not res.text
-    assert res.message == "The model does not support Transcriptions API"
+    err = res.error
+    assert err["code"] == 400 and not res.text
+    assert err[
+        "message"] == "The model does not support Transcriptions API"


 @pytest.mark.asyncio
@@ -133,12 +135,15 @@ async def test_completion_endpoints():
                 "role": "system",
                 "content": "You are a helpful assistant."
             }])
-    assert res.code == 400
-    assert res.message == "The model does not support Chat Completions API"
+    err = res.error
+    assert err["code"] == 400
+    assert err[
+        "message"] == "The model does not support Chat Completions API"

     res = await client.completions.create(model=model_name, prompt="Hello")
-    assert res.code == 400
-    assert res.message == "The model does not support Completions API"
+    err = res.error
+    assert err["code"] == 400
+    assert err["message"] == "The model does not support Completions API"


 @pytest.mark.asyncio

tests/entrypoints/openai/test_translation_validation.py (3 additions, 2 deletions)

@@ -73,8 +73,9 @@ async def test_non_asr_model(foscolo):
     res = await client.audio.translations.create(model=model_name,
                                                  file=foscolo,
                                                  temperature=0.0)
-    assert res.code == 400 and not res.text
-    assert res.message == "The model does not support Translations API"
+    err = res.error
+    assert err["code"] == 400 and not res.text
+    assert err["message"] == "The model does not support Translations API"


 @pytest.mark.asyncio

vllm/entrypoints/openai/api_server.py (26 additions, 24 deletions)

@@ -62,7 +62,8 @@
                                               DetokenizeRequest,
                                               DetokenizeResponse,
                                               EmbeddingRequest,
-                                              EmbeddingResponse, ErrorResponse,
+                                              EmbeddingResponse, ErrorInfo,
+                                              ErrorResponse,
                                               LoadLoRAAdapterRequest,
                                               PoolingRequest, PoolingResponse,
                                               RerankRequest, RerankResponse,
@@ -506,7 +507,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):

     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
-                            status_code=generator.code)
+                            status_code=generator.error.code)
     elif isinstance(generator, TokenizeResponse):
         return JSONResponse(content=generator.model_dump())

@@ -540,7 +541,7 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request):

     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
-                            status_code=generator.code)
+                            status_code=generator.error.code)
     elif isinstance(generator, DetokenizeResponse):
         return JSONResponse(content=generator.model_dump())

@@ -556,7 +557,7 @@ async def get_tokenizer_info(raw_request: Request):
     """Get comprehensive tokenizer information."""
     result = await tokenization(raw_request).get_tokenizer_info()
     return JSONResponse(content=result.model_dump(),
-                        status_code=result.code if isinstance(
+                        status_code=result.error.code if isinstance(
                             result, ErrorResponse) else 200)


@@ -603,7 +604,7 @@ async def create_responses(request: ResponsesRequest, raw_request: Request):

     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
-                            status_code=generator.code)
+                            status_code=generator.error.code)
     elif isinstance(generator, ResponsesResponse):
         return JSONResponse(content=generator.model_dump())
     return StreamingResponse(content=generator, media_type="text/event-stream")
@@ -620,7 +621,7 @@ async def retrieve_responses(response_id: str, raw_request: Request):

     if isinstance(response, ErrorResponse):
         return JSONResponse(content=response.model_dump(),
-                            status_code=response.code)
+                            status_code=response.error.code)
     return JSONResponse(content=response.model_dump())


@@ -635,7 +636,7 @@ async def cancel_responses(response_id: str, raw_request: Request):

     if isinstance(response, ErrorResponse):
         return JSONResponse(content=response.model_dump(),
-                            status_code=response.code)
+                            status_code=response.error.code)
     return JSONResponse(content=response.model_dump())


@@ -670,7 +671,7 @@ async def create_chat_completion(request: ChatCompletionRequest,

     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
-                            status_code=generator.code)
+                            status_code=generator.error.code)

     elif isinstance(generator, ChatCompletionResponse):
         return JSONResponse(content=generator.model_dump())
@@ -715,7 +716,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):

     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
-                            status_code=generator.code)
+                            status_code=generator.error.code)
     elif isinstance(generator, CompletionResponse):
         return JSONResponse(content=generator.model_dump())

@@ -744,7 +745,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):

     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
-                            status_code=generator.code)
+                            status_code=generator.error.code)
     elif isinstance(generator, EmbeddingResponse):
         return JSONResponse(content=generator.model_dump())

@@ -772,7 +773,7 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
     generator = await handler.create_pooling(request, raw_request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
-                            status_code=generator.code)
+                            status_code=generator.error.code)
     elif isinstance(generator, PoolingResponse):
         return JSONResponse(content=generator.model_dump())

@@ -792,7 +793,7 @@ async def create_classify(request: ClassificationRequest,
     generator = await handler.create_classify(request, raw_request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
-                            status_code=generator.code)
+                            status_code=generator.error.code)

     elif isinstance(generator, ClassificationResponse):
         return JSONResponse(content=generator.model_dump())
@@ -821,7 +822,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):
     generator = await handler.create_score(request, raw_request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
-                            status_code=generator.code)
+                            status_code=generator.error.code)
     elif isinstance(generator, ScoreResponse):
         return JSONResponse(content=generator.model_dump())

@@ -881,7 +882,7 @@ async def create_transcriptions(raw_request: Request,

     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
-                            status_code=generator.code)
+                            status_code=generator.error.code)

     elif isinstance(generator, TranscriptionResponse):
         return JSONResponse(content=generator.model_dump())
@@ -922,7 +923,7 @@ async def create_translations(request: Annotated[TranslationRequest,

     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
-                            status_code=generator.code)
+                            status_code=generator.error.code)

     elif isinstance(generator, TranslationResponse):
         return JSONResponse(content=generator.model_dump())
@@ -950,7 +951,7 @@ async def do_rerank(request: RerankRequest, raw_request: Request):
     generator = await handler.do_rerank(request, raw_request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
-                            status_code=generator.code)
+                            status_code=generator.error.code)
     elif isinstance(generator, RerankResponse):
         return JSONResponse(content=generator.model_dump())

@@ -1175,7 +1176,7 @@ async def invocations(raw_request: Request):
     msg = ("Cannot find suitable handler for request. "
            f"Expected one of: {type_names}")
     res = base(raw_request).create_error_response(message=msg)
-    return JSONResponse(content=res.model_dump(), status_code=res.code)
+    return JSONResponse(content=res.model_dump(), status_code=res.error.code)


 if envs.VLLM_TORCH_PROFILER_DIR:
@@ -1211,7 +1212,7 @@ async def load_lora_adapter(request: LoadLoRAAdapterRequest,
     response = await handler.load_lora_adapter(request)
     if isinstance(response, ErrorResponse):
         return JSONResponse(content=response.model_dump(),
-                            status_code=response.code)
+                            status_code=response.error.code)

     return Response(status_code=200, content=response)

@@ -1223,7 +1224,7 @@ async def unload_lora_adapter(request: UnloadLoRAAdapterRequest,
     response = await handler.unload_lora_adapter(request)
     if isinstance(response, ErrorResponse):
         return JSONResponse(content=response.model_dump(),
-                            status_code=response.code)
+                            status_code=response.error.code)

     return Response(status_code=200, content=response)

@@ -1502,9 +1503,10 @@ def build_app(args: Namespace) -> FastAPI:

     @app.exception_handler(HTTPException)
     async def http_exception_handler(_: Request, exc: HTTPException):
-        err = ErrorResponse(message=exc.detail,
-                            type=HTTPStatus(exc.status_code).phrase,
-                            code=exc.status_code)
+        err = ErrorResponse(
+            error=ErrorInfo(message=exc.detail,
+                            type=HTTPStatus(exc.status_code).phrase,
+                            code=exc.status_code))
         return JSONResponse(err.model_dump(), status_code=exc.status_code)

     @app.exception_handler(RequestValidationError)
@@ -1518,9 +1520,9 @@ async def validation_exception_handler(_: Request,
         else:
             message = exc_str

-        err = ErrorResponse(message=message,
-                            type=HTTPStatus.BAD_REQUEST.phrase,
-                            code=HTTPStatus.BAD_REQUEST)
+        err = ErrorResponse(error=ErrorInfo(message=message,
+                                            type=HTTPStatus.BAD_REQUEST.phrase,
+                                            code=HTTPStatus.BAD_REQUEST))
         return JSONResponse(err.model_dump(),
                             status_code=HTTPStatus.BAD_REQUEST)
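
Every endpoint above makes the same mechanical change: the HTTP status code is now read from the nested error field. A simplified sketch of the shared pattern (the real handlers inline this check rather than sharing a helper; to_json_response is a hypothetical name for illustration):

from fastapi.responses import JSONResponse

from vllm.entrypoints.openai.protocol import ErrorResponse


def to_json_response(result) -> JSONResponse:
    # ErrorResponse now nests its fields, so the status code comes from
    # result.error.code instead of the removed top-level result.code.
    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.error.code)
    # Success responses keep the default 200 status.
    return JSONResponse(content=result.model_dump())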

vllm/entrypoints/openai/protocol.py (5 additions, 2 deletions)

@@ -78,14 +78,17 @@ def __log_extra_fields__(cls, data, handler):
         return result


-class ErrorResponse(OpenAIBaseModel):
-    object: str = "error"
+class ErrorInfo(OpenAIBaseModel):
     message: str
     type: str
     param: Optional[str] = None
     code: int


+class ErrorResponse(OpenAIBaseModel):
+    error: ErrorInfo
+
+
 class ModelPermission(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
     object: str = "model_permission"
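
The nested models serialize to the upstream shape directly. A quick illustration of constructing and dumping the new ErrorResponse (hypothetical usage, not part of the diff; it assumes vllm is importable):

from http import HTTPStatus

from vllm.entrypoints.openai.protocol import ErrorInfo, ErrorResponse

err = ErrorResponse(error=ErrorInfo(
    message="The model does not support Completions API",
    type=HTTPStatus.BAD_REQUEST.phrase,
    code=HTTPStatus.BAD_REQUEST))

payload = err.model_dump()
assert payload["error"]["code"] == 400
assert payload["error"]["type"] == "Bad Request"
assert "object" not in payload  # the old top-level tag is gone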

vllm/entrypoints/openai/run_batch.py (1 addition, 1 deletion)

@@ -302,7 +302,7 @@ async def run_request(serving_engine_func: Callable,
         id=f"vllm-{random_uuid()}",
         custom_id=request.custom_id,
         response=BatchResponseData(
-            status_code=response.code,
+            status_code=response.error.code,
             request_id=f"vllm-batch-{random_uuid()}"),
         error=response,
     )
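
For batch runs, the per-request status_code is likewise sourced from the nested field, while the whole ErrorResponse is attached as the error payload. A hypothetical failed batch output line after this change (ids and message are placeholders, and the exact serialized shape of BatchRequestOutput is an assumption here):

failed_line = {
    "id": "vllm-<uuid>",
    "custom_id": "my-request-1",
    "response": {
        "status_code": 400,
        "request_id": "vllm-batch-<uuid>",
    },
    # error=response serializes the full ErrorResponse, so the
    # OpenAI-style nesting appears here as well.
    "error": {
        "error": {
            "message": "...",
            "type": "Bad Request",
            "param": None,
            "code": 400,
        }
    },
}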
