Commit 87d41c8

[BUGFIX] [FRONTEND] Correct chat logprobs (#5029)
Co-authored-by: Breno Faria <[email protected]>
1 parent e07aff9 commit 87d41c8

File tree

6 files changed (+361 / -98 lines changed)
tests/async_engine/test_openapi_server_ray.py

Lines changed: 4 additions & 2 deletions

@@ -94,8 +94,10 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI):
         chat_completion.choices) == 1
     assert chat_completion.choices[0].message is not None
     assert chat_completion.choices[0].logprobs is not None
-    assert chat_completion.choices[0].logprobs.top_logprobs is not None
-    assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5
+    assert chat_completion.choices[0].logprobs.content[
+        0].top_logprobs is not None
+    assert len(
+        chat_completion.choices[0].logprobs.content[0].top_logprobs) == 5
     message = chat_completion.choices[0].message
     assert message.content is not None and len(message.content) >= 10
     assert message.role == "assistant"
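
For orientation, here is a minimal client-side sketch of how the corrected chat logprobs are consumed, mirroring the assertions above. The server URL, API key, and model name are placeholders; it assumes a vLLM OpenAI-compatible server is already running.

import asyncio

import openai

# Placeholder endpoint, API key, and model name, for illustration only.
client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")


async def show_chat_logprobs():
    chat_completion = await client.chat.completions.create(
        model="my-model",  # placeholder model name
        messages=[{"role": "user", "content": "what is 1+1?"}],
        max_tokens=5,
        logprobs=True,   # return logprobs for the sampled tokens
        top_logprobs=5,  # and up to 5 alternatives per position
    )
    # Chat logprobs now live under .content, one entry per generated token,
    # each carrying the sampled token's logprob and its top alternatives.
    for entry in chat_completion.choices[0].logprobs.content:
        print(entry.token, entry.logprob)
        for alt in entry.top_logprobs:
            print("    alt:", alt.token, alt.logprob)


if __name__ == "__main__":
    asyncio.run(show_chat_logprobs())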

tests/entrypoints/test_openai_server.py

Lines changed: 181 additions & 28 deletions

@@ -184,6 +184,26 @@ async def test_single_completion(server, client: openai.AsyncOpenAI,
         completion.choices[0].text) >= 5
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    # first test base model, then test loras
+    "model_name",
+    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
+)
+async def test_no_logprobs(server, client: openai.AsyncOpenAI,
+                           model_name: str):
+    # test using token IDs
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=[0, 0, 0, 0, 0],
+        max_tokens=5,
+        temperature=0.0,
+        logprobs=None,
+    )
+    choice = completion.choices[0]
+    assert choice.logprobs is None
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     # first test base model, then test loras
@@ -203,7 +223,72 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
     choice = completion.choices[0]
     assert choice.logprobs is not None
     assert choice.logprobs.token_logprobs is not None
-    assert choice.logprobs.top_logprobs is None
+    assert choice.logprobs.top_logprobs is not None
+    assert len(choice.logprobs.top_logprobs[0]) <= 1
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_some_logprobs(server, client: openai.AsyncOpenAI,
+                             model_name: str):
+    # test using token IDs
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=[0, 0, 0, 0, 0],
+        max_tokens=5,
+        temperature=0.0,
+        logprobs=5,
+    )
+    choice = completion.choices[0]
+    assert choice.logprobs is not None
+    assert choice.logprobs.token_logprobs is not None
+    assert choice.logprobs.top_logprobs is not None
+    assert len(choice.logprobs.top_logprobs[0]) <= 6
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_too_many_completion_logprobs(server, client: openai.AsyncOpenAI,
+                                            model_name: str):
+
+    with pytest.raises(
+            (openai.BadRequestError, openai.APIError)):  # test using token IDs
+        await client.completions.create(
+            model=MODEL_NAME,
+            prompt=[0, 0, 0, 0, 0],
+            max_tokens=5,
+            temperature=0.0,
+            logprobs=6,
+        )
+        ...
+    with pytest.raises(
+            (openai.BadRequestError, openai.APIError)):  # test using token IDs
+        stream = await client.completions.create(
+            model=MODEL_NAME,
+            prompt=[0, 0, 0, 0, 0],
+            max_tokens=5,
+            temperature=0.0,
+            logprobs=6,
+            stream=True,
+        )
+        async for chunk in stream:
+            ...
+
+    # the server should still work afterwards
+    completion = await client.completions.create(
+        model=model_name,
+        prompt=[0, 0, 0, 0, 0],
+        max_tokens=5,
+        temperature=0.0,
+    )
+    completion = completion.choices[0].text
+    assert completion is not None and len(completion) >= 0
 
 
 @pytest.mark.asyncio
@@ -233,8 +318,10 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
         chat_completion.choices) == 1
     assert chat_completion.choices[0].message is not None
     assert chat_completion.choices[0].logprobs is not None
-    assert chat_completion.choices[0].logprobs.top_logprobs is not None
-    assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5
+    assert chat_completion.choices[0].logprobs.content[
+        0].top_logprobs is not None
+    assert len(
+        chat_completion.choices[0].logprobs.content[0].top_logprobs) == 5
     message = chat_completion.choices[0].message
     assert message.content is not None and len(message.content) >= 10
     assert message.role == "assistant"
@@ -251,10 +338,93 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
     assert message.content is not None and len(message.content) >= 0
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    # first test base model, then test loras
+    "model_name",
+    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
+)
+async def test_no_logprobs_chat(server, client: openai.AsyncOpenAI,
+                                model_name: str):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "what is 1+1?"
+    }]
+
+    chat_completion = await client.chat.completions.create(model=model_name,
+                                                           messages=messages,
+                                                           max_tokens=5,
+                                                           temperature=0.0,
+                                                           logprobs=False)
+
+    choice = chat_completion.choices[0]
+    assert choice.logprobs is None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    # just test 1 lora hereafter
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_zero_logprobs_chat(server, client: openai.AsyncOpenAI,
+                                  model_name: str):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "what is 1+1?"
+    }]
+
+    chat_completion = await client.chat.completions.create(model=model_name,
+                                                           messages=messages,
+                                                           max_tokens=5,
+                                                           temperature=0.0,
+                                                           logprobs=True,
+                                                           top_logprobs=0)
+
+    choice = chat_completion.choices[0]
+    assert choice.logprobs is not None
+    assert choice.logprobs.content is not None
+    assert len(choice.logprobs.content[0].top_logprobs) <= 1
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_some_logprobs_chat(server, client: openai.AsyncOpenAI,
+                                  model_name: str):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "what is 1+1?"
+    }]
+
+    chat_completion = await client.chat.completions.create(model=model_name,
+                                                           messages=messages,
+                                                           max_tokens=5,
+                                                           temperature=0.0,
+                                                           logprobs=True,
+                                                           top_logprobs=5)
+
+    choice = chat_completion.choices[0]
+    assert choice.logprobs is not None
+    assert choice.logprobs.content is not None
+    assert len(choice.logprobs.content[0].top_logprobs) <= 6
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
-                                 model_name: str):
+async def test_too_many_chat_logprobs(server, client: openai.AsyncOpenAI,
+                                      model_name: str):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -263,13 +433,13 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
         "content": "what is 1+1?"
    }]
 
-    # Default max_logprobs is 5, so this should raise an error
+    # Default max_logprobs is 20, so this should raise an error
     with pytest.raises((openai.BadRequestError, openai.APIError)):
         stream = await client.chat.completions.create(model=model_name,
                                                       messages=messages,
                                                       max_tokens=10,
                                                       logprobs=True,
-                                                      top_logprobs=10,
+                                                      top_logprobs=21,
                                                       stream=True)
         async for chunk in stream:
             ...
@@ -279,25 +449,9 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
                                              messages=messages,
                                              max_tokens=10,
                                              logprobs=True,
-                                             top_logprobs=10,
+                                             top_logprobs=30,
                                              stream=False)
 
-    with pytest.raises((openai.BadRequestError, openai.APIError)):
-        stream = await client.completions.create(model=model_name,
-                                                 prompt="Test",
-                                                 max_tokens=10,
-                                                 logprobs=10,
-                                                 stream=True)
-        async for chunk in stream:
-            ...
-
-    with pytest.raises(openai.BadRequestError):
-        await client.completions.create(model=model_name,
-                                        prompt="Test",
-                                        max_tokens=10,
-                                        logprobs=10,
-                                        stream=False)
-
     # the server should still work afterwards
     chat_completion = await client.chat.completions.create(model=model_name,
                                                            messages=messages,
@@ -744,13 +898,12 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
         top_logprobs=5,
         extra_body=dict(guided_choice=TEST_CHOICE,
                         guided_decoding_backend=guided_decoding_backend))
-    top_logprobs = chat_completion.choices[0].logprobs.top_logprobs
+    top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs
 
     # -9999.0 is the minimum logprob returned by OpenAI
     assert all(
-        isinstance(logprob, float) and logprob >= -9999.0
-        for token_dict in top_logprobs
-        for token, logprob in token_dict.items())
+        isinstance(token.logprob, float) and token.logprob >= -9999.0
+        for token in top_logprobs)
 
 
 @pytest.mark.asyncio
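
The request-side validation added in protocol.py (next file) is what the tests above exercise from the client. A hedged sketch of the two failure modes, again with a placeholder server URL, API key, and model name; depending on how the error propagates it may surface as openai.BadRequestError or openai.APIError, as in the tests.

import asyncio

import openai

# Placeholder endpoint, API key, and model name, for illustration only.
client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

MESSAGES = [{"role": "user", "content": "what is 1+1?"}]


async def demo_logprob_validation():
    # top_logprobs without logprobs=True is rejected by the new validator.
    try:
        await client.chat.completions.create(model="my-model",
                                             messages=MESSAGES,
                                             max_tokens=5,
                                             top_logprobs=5)
    except (openai.BadRequestError, openai.APIError) as err:
        print("rejected:", err)

    # top_logprobs outside the interval [0, 20] is rejected as well.
    try:
        await client.chat.completions.create(model="my-model",
                                             messages=MESSAGES,
                                             max_tokens=5,
                                             logprobs=True,
                                             top_logprobs=21)
    except (openai.BadRequestError, openai.APIError) as err:
        print("rejected:", err)


if __name__ == "__main__":
    asyncio.run(demo_logprob_validation())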

vllm/entrypoints/openai/protocol.py

Lines changed: 43 additions & 7 deletions

@@ -250,6 +250,19 @@ def check_guided_decoding_count(cls, data):
                 "('guided_json', 'guided_regex' or 'guided_choice').")
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_logprobs(cls, data):
+        if "top_logprobs" in data and data["top_logprobs"] is not None:
+            if "logprobs" not in data or data["logprobs"] is False:
+                raise ValueError(
+                    "when using `top_logprobs`, `logprobs` must be set to true."
+                )
+            elif not 0 <= data["top_logprobs"] <= 20:
+                raise ValueError(
+                    "`top_logprobs` must be a value in the interval [0, 20].")
+        return data
+
 
 class CompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
@@ -396,6 +409,15 @@ def check_guided_decoding_count(cls, data):
                 "('guided_json', 'guided_regex' or 'guided_choice').")
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_logprobs(cls, data):
+        if "logprobs" in data and data[
+                "logprobs"] is not None and not 0 <= data["logprobs"] <= 5:
+            raise ValueError(("if passed, `logprobs` must be a value",
+                              " in the interval [0, 5]."))
+        return data
+
 
 class EmbeddingRequest(BaseModel):
     # Ordered by official OpenAI API documentation
@@ -415,7 +437,7 @@ def to_pooling_params(self):
         return PoolingParams(additional_data=self.additional_data)
 
 
-class LogProbs(OpenAIBaseModel):
+class CompletionLogProbs(OpenAIBaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
     tokens: List[str] = Field(default_factory=list)
@@ -425,7 +447,7 @@ class LogProbs(OpenAIBaseModel):
 class CompletionResponseChoice(OpenAIBaseModel):
     index: int
     text: str
-    logprobs: Optional[LogProbs] = None
+    logprobs: Optional[CompletionLogProbs] = None
     finish_reason: Optional[str] = None
     stop_reason: Optional[Union[int, str]] = Field(
         default=None,
@@ -448,7 +470,7 @@ class CompletionResponse(OpenAIBaseModel):
 class CompletionResponseStreamChoice(OpenAIBaseModel):
     index: int
     text: str
-    logprobs: Optional[LogProbs] = None
+    logprobs: Optional[CompletionLogProbs] = None
     finish_reason: Optional[str] = None
     stop_reason: Optional[Union[int, str]] = Field(
         default=None,
@@ -488,11 +510,25 @@ class ChatMessage(OpenAIBaseModel):
     content: str
 
 
+class ChatCompletionLogProb(OpenAIBaseModel):
+    token: str
+    logprob: float = -9999.0
+    bytes: Optional[List[int]] = None
+
+
+class ChatCompletionLogProbsContent(ChatCompletionLogProb):
+    top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)
+
+
+class ChatCompletionLogProbs(OpenAIBaseModel):
+    content: Optional[List[ChatCompletionLogProbsContent]] = None
+
+
 class ChatCompletionResponseChoice(OpenAIBaseModel):
     index: int
     message: ChatMessage
-    logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[str] = None
+    logprobs: Optional[ChatCompletionLogProbs] = None
+    finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
     stop_reason: Optional[Union[int, str]] = None
 
 
@@ -513,8 +549,8 @@ class DeltaMessage(OpenAIBaseModel):
 class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
     index: int
     delta: DeltaMessage
-    logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[str] = None
+    logprobs: Optional[ChatCompletionLogProbs] = None
+    finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
     stop_reason: Optional[Union[int, str]] = None
 
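To make the new response schema concrete, here is a small standalone sketch of the chat logprob models introduced above. It substitutes plain pydantic.BaseModel for vLLM's OpenAIBaseModel (which is not shown in this diff), and the token values are purely illustrative.

from typing import List, Optional

from pydantic import BaseModel, Field


# Standalone re-creation of the new chat logprob models, using plain
# pydantic.BaseModel in place of vLLM's OpenAIBaseModel.
class ChatCompletionLogProb(BaseModel):
    token: str
    logprob: float = -9999.0
    bytes: Optional[List[int]] = None


class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)


class ChatCompletionLogProbs(BaseModel):
    content: Optional[List[ChatCompletionLogProbsContent]] = None


# One entry per generated token; each entry carries the sampled token's
# logprob plus its top-k alternatives, matching the OpenAI chat schema.
logprobs = ChatCompletionLogProbs(content=[
    ChatCompletionLogProbsContent(
        token="2",
        logprob=-0.1,
        bytes=[50],
        top_logprobs=[
            ChatCompletionLogProb(token="2", logprob=-0.1, bytes=[50]),
            ChatCompletionLogProb(token="4", logprob=-3.2, bytes=[52]),
        ],
    )
])
print(logprobs.model_dump_json(indent=2))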