
Commit 1e2675e

Add logprobs additional output
1 parent 2e1a223 commit 1e2675e

File tree

ci/L0_additional_outputs_vllm/additional_outputs_test.py
src/model.py

2 files changed: +127 -44 lines changed


ci/L0_additional_outputs_vllm/additional_outputs_test.py

Lines changed: 70 additions & 22 deletions
@@ -37,13 +37,20 @@ class TestAdditionalOutputs:
     _sampling_parameters = {"temperature": "0", "top_p": "1"}
     _prompt = "In this example,"
 
+    def _get_sampling_parameters(self, logprobs=None):
+        sampling_parameters = self._sampling_parameters.copy()
+        if logprobs is not None:
+            sampling_parameters["logprobs"] = logprobs
+        return sampling_parameters
+
     def _get_inputs(
         self,
         prompt,
         stream=True,
         sampling_parameters=None,
         return_finish_reason=None,
         return_cumulative_logprob=None,
+        return_logprobs=None,
         return_num_input_tokens=None,
         return_num_output_tokens=None,
     ):
@@ -77,6 +84,10 @@ def _get_inputs(
                 np.array([return_cumulative_logprob], dtype=bool)
             )
 
+        if return_logprobs is not None:
+            inputs.append(grpcclient.InferInput("return_logprobs", [1], "BOOL"))
+            inputs[-1].set_data_from_numpy(np.array([return_logprobs], dtype=bool))
+
         if return_num_input_tokens is not None:
             inputs.append(grpcclient.InferInput("return_num_input_tokens", [1], "BOOL"))
             inputs[-1].set_data_from_numpy(
@@ -96,12 +107,12 @@ def _get_inputs(
     def _callback(self, result, error):
         self._responses.append({"result": result, "error": error})
 
-    def _llm_infer(self, inputs):
+    def _llm_infer(self, inputs, sampling_parameters):
         self._responses = []
         with grpcclient.InferenceServerClient(self._grpc_url) as client:
             client.start_stream(self._callback)
             client.async_stream_infer(
-                self._model_name, inputs=inputs, parameters=self._sampling_parameters
+                self._model_name, inputs=inputs, parameters=sampling_parameters
             )
             client.stop_stream()
         assert len(self._responses) > 0
@@ -142,6 +153,51 @@ def _assert_cumulative_logprob(self, return_cumulative_logprob):
             assert cumulative_logprob != prev_cumulative_logprob
             prev_cumulative_logprob = cumulative_logprob
 
+    def _assert_logprobs(
+        self, stream, sampling_parameters, return_logprobs, return_num_output_tokens
+    ):
+        for response in self._responses:
+            result, error = response["result"], response["error"]
+            assert error is None
+            logprobs_np = result.as_numpy(name="logprobs")
+            if return_logprobs is None or return_logprobs == False:
+                assert logprobs_np is None
+                continue
+            logprobs = json.loads(logprobs_np[0].decode("utf-8"))
+            if "logprobs" not in sampling_parameters:
+                assert logprobs is None
+                continue
+            assert isinstance(logprobs, list)
+            assert len(logprobs) >= 1
+            if return_num_output_tokens == True:
+                num_output_tokens = result.as_numpy(name="num_output_tokens")[0].astype(
+                    int
+                )
+                assert len(logprobs) == num_output_tokens
+            text_output_logprobs = ""
+            for logprobs_d in logprobs:
+                assert isinstance(logprobs_d, dict)
+                assert len(logprobs_d) >= 1
+                assert len(logprobs_d) <= sampling_parameters["logprobs"] + 1
+                rank_one_found = False
+                for token_id, logprob_d in logprobs_d.items():
+                    assert isinstance(token_id, str)
+                    assert len(logprob_d) == 3
+                    assert isinstance(logprob_d["logprob"], float)
+                    assert isinstance(logprob_d["rank"], int)
+                    assert isinstance(logprob_d["decoded_token"], str)
+                    if logprob_d["rank"] == 1:
+                        assert not rank_one_found
+                        rank_one_found = True
+                        text_output_logprobs += logprob_d["decoded_token"]
+                assert rank_one_found
+            text_output = result.as_numpy(name="text_output")[0].decode("utf-8")
+            if not stream:
+                # given exclude_input_in_output is not set, prepend_input is True if not
+                # streaming and False if streaming
+                text_output_logprobs = self._prompt + text_output_logprobs
+            assert text_output_logprobs == text_output
+
     def _assert_num_input_tokens(self, return_num_input_tokens):
         for response in self._responses:
             result, error = response["result"], response["error"]
@@ -163,50 +219,42 @@ def _assert_num_output_tokens(self, return_num_output_tokens):
                 assert num_output_tokens_np is None
                 continue
             num_output_tokens = num_output_tokens_np[0].astype(int)
-            # TODO: vLLM may return token ids identical to the previous one when
-            # streaming, for example:
-            #
-            # prev: None
-            # curr: text=' the', token_ids=array('l', [5])
-            #
-            # prev: text=' the', token_ids=array('l', [5, 1385])
-            # curr: text=' the term', token_ids=array('l', [5, 1385])
-            #
-            # prev: text=' the term', token_ids=array('l', [5, 1385, 44])
-            # curr: text=' the term', token_ids=array('l', [5, 1385, 44])
-            #
-            # prev: text=' the term', token_ids=array('l', [5, 1385, 44, 48])
-            # curr: text=' the term “', token_ids=array('l', [5, 1385, 44, 48])
-            #
-            # If this is no longer the case in a future release, change the assert
-            # to assert num_output_tokens > 0.
-            assert num_output_tokens >= 0
+            assert num_output_tokens > 0
 
     @pytest.mark.parametrize("stream", [True, False])
     @pytest.mark.parametrize("return_finish_reason", [None, True, False])
     @pytest.mark.parametrize("return_cumulative_logprob", [None, True, False])
+    @pytest.mark.parametrize("logprobs", [None, 0, 2])
+    @pytest.mark.parametrize("return_logprobs", [None, True, False])
     @pytest.mark.parametrize("return_num_input_tokens", [None, True, False])
     @pytest.mark.parametrize("return_num_output_tokens", [None, True, False])
     def test_additional_outputs(
         self,
         stream,
         return_finish_reason,
         return_cumulative_logprob,
+        logprobs,
+        return_logprobs,
         return_num_input_tokens,
         return_num_output_tokens,
     ):
+        sampling_parameters = self._get_sampling_parameters(logprobs=logprobs)
         inputs = self._get_inputs(
             self._prompt,
             stream=stream,
-            sampling_parameters=self._sampling_parameters,
+            sampling_parameters=sampling_parameters,
             return_finish_reason=return_finish_reason,
             return_cumulative_logprob=return_cumulative_logprob,
+            return_logprobs=return_logprobs,
             return_num_input_tokens=return_num_input_tokens,
             return_num_output_tokens=return_num_output_tokens,
         )
-        self._llm_infer(inputs)
+        self._llm_infer(inputs, sampling_parameters)
         self._assert_text_output_valid()
         self._assert_finish_reason(return_finish_reason)
         self._assert_cumulative_logprob(return_cumulative_logprob)
+        self._assert_logprobs(
+            stream, sampling_parameters, return_logprobs, return_num_output_tokens
+        )
         self._assert_num_input_tokens(return_num_input_tokens)
         self._assert_num_output_tokens(return_num_output_tokens)
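
For reference, here is a minimal client-side sketch (not part of this commit) of how the new logprobs output can be requested and read back, following the same tritonclient patterns used in the test above. The server URL and model name are placeholder assumptions:

import json

import numpy as np
import tritonclient.grpc as grpcclient

# Placeholder assumptions: adjust the URL and model name to your deployment.
GRPC_URL = "localhost:8001"
MODEL_NAME = "vllm_model"

responses = []

def callback(result, error):
    responses.append({"result": result, "error": error})

# Build the request: prompt, streaming flag, and the new return_logprobs switch.
inputs = [grpcclient.InferInput("text_input", [1], "BYTES")]
inputs[-1].set_data_from_numpy(
    np.array(["In this example,".encode("utf-8")], dtype=np.object_)
)
inputs.append(grpcclient.InferInput("stream", [1], "BOOL"))
inputs[-1].set_data_from_numpy(np.array([True], dtype=bool))
inputs.append(grpcclient.InferInput("return_logprobs", [1], "BOOL"))
inputs[-1].set_data_from_numpy(np.array([True], dtype=bool))

# "logprobs" must also be set in the sampling parameters (here: top-2 per token),
# mirroring _get_sampling_parameters in the test.
sampling_parameters = {"temperature": "0", "top_p": "1", "logprobs": 2}

with grpcclient.InferenceServerClient(GRPC_URL) as client:
    client.start_stream(callback)
    client.async_stream_infer(MODEL_NAME, inputs=inputs, parameters=sampling_parameters)
    client.stop_stream()

for response in responses:
    assert response["error"] is None
    # Each response carries a JSON string: one dict per new output token,
    # mapping token id -> {"logprob", "rank", "decoded_token"}.
    logprobs = json.loads(
        response["result"].as_numpy(name="logprobs")[0].decode("utf-8")
    )
    for per_token in logprobs:
        for token_id, entry in per_token.items():
            print(token_id, entry["rank"], entry["logprob"], repr(entry["decoded_token"]))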

src/model.py

Lines changed: 57 additions & 22 deletions
@@ -104,6 +104,12 @@ def _auto_complete_inputs_and_outputs(auto_complete_model_config):
                 "dims": [1],
                 "optional": True,
             },
+            {
+                "name": "return_logprobs",
+                "data_type": "TYPE_BOOL",
+                "dims": [1],
+                "optional": True,
+            },
             {
                 "name": "return_num_input_tokens",
                 "data_type": "TYPE_BOOL",
@@ -131,6 +137,7 @@ def _auto_complete_inputs_and_outputs(auto_complete_model_config):
             {"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]},
             {"name": "finish_reason", "data_type": "TYPE_STRING", "dims": [-1]},
             {"name": "cumulative_logprob", "data_type": "TYPE_FP32", "dims": [-1]},
+            {"name": "logprobs", "data_type": "TYPE_STRING", "dims": [-1]},
             {"name": "num_input_tokens", "data_type": "TYPE_UINT32", "dims": [1]},
             {"name": "num_output_tokens", "data_type": "TYPE_UINT32", "dims": [-1]},
         ]
@@ -388,6 +395,7 @@ def _get_input_tensors(self, request):
         additional_outputs = {
             "return_finish_reason": None,
             "return_cumulative_logprob": None,
+            "return_logprobs": None,
             "return_num_input_tokens": None,
             "return_num_output_tokens": None,
         }
@@ -455,26 +463,27 @@ def response_loop(self):
                     self.ongoing_request_count -= 1
 
     def _create_response(
-        self, prev_request_output, request_output, prepend_input, additional_outputs
+        self, request_output_state, request_output, prepend_input, additional_outputs
     ):
         output_tensors = []
 
         # text_output
         prepend_prompt = ""
-        if prev_request_output is None:
+        if "prev_lens_text_output" not in request_output_state:
             # this is the first response
             if prepend_input:
                 prepend_prompt = request_output.prompt
-            prev_lens = [0] * len(request_output.outputs)
-        else:
-            # this is a subsequent response
-            prev_lens = [
-                len(prev_output.text) for prev_output in prev_request_output.outputs
-            ]
+            request_output_state["prev_lens_text_output"] = [0] * len(
+                request_output.outputs
+            )
+        prev_lens = request_output_state["prev_lens_text_output"]
         text_output = [
             (prepend_prompt + output.text[prev_len:]).encode("utf-8")
             for output, prev_len in zip(request_output.outputs, prev_lens)
         ]
+        request_output_state["prev_lens_text_output"] = [
+            len(output.text) for output in request_output.outputs
+        ]
         output_tensors.append(
             pb_utils.Tensor(
                 "text_output", np.asarray(text_output, dtype=self.output_dtype)
@@ -504,6 +513,35 @@ def _create_response(
                 )
             )
 
+        # logprobs
+        if additional_outputs["return_logprobs"]:
+            if "prev_lens_logprobs" not in request_output_state:
+                request_output_state["prev_lens_logprobs"] = [0] * len(
+                    request_output.outputs
+                )
+            logprobs = []
+            for i in range(len(request_output.outputs)):
+                output = request_output.outputs[i]
+                if output.logprobs is None:
+                    logprobs.append("null".encode("utf-8"))
+                    continue
+                prev_len = request_output_state["prev_lens_logprobs"][i]
+                request_output_state["prev_lens_logprobs"][i] = len(output.logprobs)
+                logprobs_py = []
+                for logprob_d_vllm in output.logprobs[prev_len:]:
+                    logprob_d_py = {}
+                    for token_id, logprob_vllm in logprob_d_vllm.items():
+                        logprob_d_py[token_id] = {
+                            "logprob": logprob_vllm.logprob,
+                            "rank": logprob_vllm.rank,
+                            "decoded_token": logprob_vllm.decoded_token,
+                        }
+                    logprobs_py.append(logprob_d_py)
+                logprobs.append(json.dumps(logprobs_py).encode("utf-8"))
+            output_tensors.append(
+                pb_utils.Tensor("logprobs", np.asarray(logprobs, dtype=np.object_))
+            )
+
         # num_input_tokens
         if additional_outputs["return_num_input_tokens"]:
             num_input_tokens = len(request_output.prompt_token_ids)
@@ -515,19 +553,18 @@ def _create_response(
 
         # num_output_tokens
         if additional_outputs["return_num_output_tokens"]:
-            if prev_request_output is None:
-                # this is the first response
-                prev_lens = [0] * len(request_output.outputs)
-            else:
-                # this is a subsequent response
-                prev_lens = [
-                    len(prev_output.token_ids)
-                    for prev_output in prev_request_output.outputs
-                ]
+            if "prev_lens_num_output_tokens" not in request_output_state:
+                request_output_state["prev_lens_num_output_tokens"] = [0] * len(
+                    request_output.outputs
+                )
+            prev_lens = request_output_state["prev_lens_num_output_tokens"]
             num_output_tokens = [
                 (len(output.token_ids) - prev_len)
                 for output, prev_len in zip(request_output.outputs, prev_lens)
             ]
+            request_output_state["prev_lens_num_output_tokens"] = [
+                len(output.token_ids) for output in request_output.outputs
+            ]
             output_tensors.append(
                 pb_utils.Tensor(
                     "num_output_tokens", np.asarray(num_output_tokens, dtype=np.uint32)
@@ -572,7 +609,7 @@ async def generate(self, request):
                 request_id, prompt, sampling_params, lora_request=lora_request
             )
 
-            prev_request_output = None
+            request_output_state = {}
             async for request_output in response_iterator:
                 # Cancellation state will be checked by the response loop and written to
                 # the response state if streaming. If not streaming, cancellation state
@@ -605,7 +642,7 @@ async def generate(self, request):
                 # Send each response if streaming.
                 if stream:
                     response = self._create_response(
-                        prev_request_output,
+                        request_output_state,
                        request_output,
                         prepend_input=False,
                         additional_outputs=additional_outputs,
@@ -617,13 +654,11 @@ async def generate(self, request):
                         decrement_ongoing_request_count = False
                     self._response_queue.put_nowait((response_state, response, flags))
 
-                prev_request_output = request_output
-
             # Send the last response which contains all the outputs if not streaming.
             if not stream:
                 response_sender.send(
                     self._create_response(
-                        prev_request_output=None,
+                        request_output_state={},
                         request_output=request_output,
                         prepend_input=prepend_input,
                         additional_outputs=additional_outputs,
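
For illustration, a deserialized logprobs payload from a single response might look like the sketch below. The token ids, probabilities, and decoded strings are invented; only the structure follows the serialization code above and the assertions in the test (with "logprobs": 2 each per-token dict holds at most logprobs + 1 entries, exactly one of which has rank 1):

import json

# Invented example payload covering two new output tokens of one response.
example = json.dumps(
    [
        {
            "262": {"logprob": -0.71, "rank": 1, "decoded_token": " the"},
            "309": {"logprob": -1.43, "rank": 2, "decoded_token": " a"},
        },
        {
            "1917": {"logprob": -0.35, "rank": 1, "decoded_token": " world"},
            "3303": {"logprob": -2.02, "rank": 2, "decoded_token": " language"},
        },
    ]
)

# Concatenating the rank-1 decoded tokens reconstructs this response's text_output
# (plus the prompt in the non-streaming case), which is what _assert_logprobs verifies.
text = "".join(
    entry["decoded_token"]
    for per_token in json.loads(example)
    for entry in per_token.values()
    if entry["rank"] == 1
)
print(text)  # " the world"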
