Commit 80dd037

[fix] Sampling Parameters related improvements (#80)
1 parent d061556 commit 80dd037

5 files changed, +373 −52 lines changed


ci/L0_backend_vllm/accuracy_test/accuracy_test.py

Lines changed: 105 additions & 13 deletions
@@ -1,4 +1,4 @@
-# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions
@@ -26,6 +26,7 @@
 
 import argparse
 import asyncio
+import json
 import pickle
 import sys
 import unittest
@@ -36,6 +37,7 @@
 from vllm import SamplingParams
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.sampling_params import GuidedDecodingParams
 from vllm.utils import random_uuid
 
 sys.path.append("../../common")
@@ -53,14 +55,22 @@
     "The future of AI is",
 ]
 
+GUIDED_PROMPTS = ["Classify intent of the sentence: Harry Potter is underrated. "]
+
 SAMPLING_PARAMETERS = {"temperature": 0, "top_p": 1}
 
 
-async def generate_python_vllm_output(prompt, llm_engine):
+async def generate_python_vllm_output(
+    prompt,
+    llm_engine,
+    sampling_params=SamplingParams(**SAMPLING_PARAMETERS),
+    guided_generation=None,
+):
     request_id = random_uuid()
-    sampling_params = SamplingParams(**SAMPLING_PARAMETERS)
     python_vllm_output = None
     last_output = None
+    if guided_generation:
+        sampling_params.guided_decoding = guided_generation
 
     async for vllm_output in llm_engine.generate(prompt, sampling_params, request_id):
         last_output = vllm_output
@@ -69,24 +79,28 @@ async def generate_python_vllm_output(prompt, llm_engine):
     python_vllm_output = [
         (prompt + output.text).encode("utf-8") for output in last_output.outputs
     ]
-
     return python_vllm_output
 
 
-def prepare_vllm_baseline_outputs():
+def prepare_vllm_baseline_outputs(
+    export_file="vllm_baseline_output.pkl", prompts=PROMPTS, guided_generation=None
+):
     """
     Helper function that starts async vLLM engine and generates output for each
-    prompt in `PROMPTS`. Saves resulted baselines in `vllm_baseline_output.pkl`
+    prompt in `prompts`. Saves resulted baselines in `vllm_baseline_output.pkl`
     for further use.
     """
     llm_engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**VLLM_ENGINE_CONFIG))
     python_vllm_output = []
-    for i in range(len(PROMPTS)):
+    for i in range(len(prompts)):
         python_vllm_output.extend(
-            asyncio.run(generate_python_vllm_output(PROMPTS[i], llm_engine))
+            asyncio.run(
+                generate_python_vllm_output(
+                    prompts[i], llm_engine, guided_generation=guided_generation
+                )
+            )
         )
-
-    with open("vllm_baseline_output.pkl", "wb") as f:
+    with open(export_file, "wb") as f:
         pickle.dump(python_vllm_output, f)
 
     return
@@ -96,6 +110,9 @@ class VLLMTritonAccuracyTest(TestResultCollector):
     def setUp(self):
         self.triton_client = grpcclient.InferenceServerClient(url="localhost:8001")
         self.vllm_model_name = "vllm_opt"
+
+    def test_vllm_model(self):
+        # Reading and verifying baseline data
         self.python_vllm_output = []
         with open("vllm_baseline_output.pkl", "rb") as f:
             self.python_vllm_output = pickle.load(f)
@@ -116,11 +133,9 @@ def setUp(self):
             ),
         )
 
-    def test_vllm_model(self):
         user_data = UserData()
         stream = False
         triton_vllm_output = []
-
         self.triton_client.start_stream(callback=partial(callback, user_data))
         for i in range(len(PROMPTS)):
             request_data = create_vllm_request(
@@ -131,7 +146,7 @@ def test_vllm_model(self):
                 request_id=request_data["request_id"],
                 inputs=request_data["inputs"],
                 outputs=request_data["outputs"],
-                parameters=SAMPLING_PARAMETERS,
+                parameters=request_data["parameters"],
             )
 
         for i in range(len(PROMPTS)):
@@ -146,6 +161,63 @@ def test_vllm_model(self):
         self.triton_client.stop_stream()
         self.assertEqual(self.python_vllm_output.sort(), triton_vllm_output.sort())
 
+    def test_guided_decoding(self):
+        # Reading and verifying baseline data
+        self.python_vllm_output = []
+        with open("vllm_guided_baseline_output.pkl", "rb") as f:
+            self.python_vllm_output = pickle.load(f)
+
+        self.assertNotEqual(
+            self.python_vllm_output,
+            [],
+            "Loaded baseline outputs' list should not be empty",
+        )
+        self.assertIsNotNone(
+            self.python_vllm_output, "Loaded baseline outputs' list should not be None"
+        )
+        self.assertEqual(
+            len(self.python_vllm_output),
+            len(GUIDED_PROMPTS),
+            "Unexpected number of baseline outputs loaded, expected {}, but got {}".format(
+                len(GUIDED_PROMPTS), len(self.python_vllm_output)
+            ),
+        )
+
+        user_data = UserData()
+        stream = False
+        triton_vllm_output = []
+
+        self.triton_client.start_stream(callback=partial(callback, user_data))
+        sampling_params = SAMPLING_PARAMETERS
+        guided_decoding_params = {
+            "choice": ["Positive", "Negative"],
+            "backend": "outlines",
+        }
+        sampling_params["guided_decoding"] = json.dumps(guided_decoding_params)
+        for i in range(len(GUIDED_PROMPTS)):
+            request_data = create_vllm_request(
+                GUIDED_PROMPTS[i], i, stream, sampling_params, self.vllm_model_name
+            )
+            self.triton_client.async_stream_infer(
+                model_name=self.vllm_model_name,
+                request_id=request_data["request_id"],
+                inputs=request_data["inputs"],
+                outputs=request_data["outputs"],
+                parameters=request_data["parameters"],
+            )
+
+        for i in range(len(GUIDED_PROMPTS)):
+            result = user_data._completed_requests.get()
+            self.assertIsNot(type(result), InferenceServerException, str(result))
+
+            output = result.as_numpy("text_output")
+            self.assertIsNotNone(output, "`text_output` should not be None")
+
+            triton_vllm_output.extend(output)
+
+        self.triton_client.stop_stream()
+        self.assertEqual(self.python_vllm_output.sort(), triton_vllm_output.sort())
+
     def tearDown(self):
         self.triton_client.close()
 
@@ -159,9 +231,29 @@ def tearDown(self):
         default=False,
         help="Generates baseline output for accuracy tests",
    )
+    parser.add_argument(
+        "--generate-guided-baseline",
+        action="store_true",
+        required=False,
+        default=False,
+        help="Generates baseline output for accuracy tests",
+    )
     FLAGS = parser.parse_args()
     if FLAGS.generate_baseline:
         prepare_vllm_baseline_outputs()
         exit(0)
 
+    if FLAGS.generate_guided_baseline:
+        guided_decoding_params = {
+            "choice": ["Positive", "Negative"],
+            "backend": "outlines",
+        }
+        guided_generation = GuidedDecodingParams(**guided_decoding_params)
+        prepare_vllm_baseline_outputs(
+            export_file="vllm_guided_baseline_output.pkl",
+            prompts=GUIDED_PROMPTS,
+            guided_generation=guided_generation,
+        )
+        exit(0)
+
     unittest.main()
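
The new test_guided_decoding case above sends guided-decoding settings through the request's `parameters` field, with the `guided_decoding` entry itself JSON-serialized. Below is a minimal standalone sketch of such a request without the repo's create_vllm_request() helper. It assumes a tritonclient release whose async_stream_infer() accepts the `parameters` argument (as used in the test), and the backend's usual "text_input"/"stream" inputs alongside the "text_output" output read in the test; adjust names to your model configuration.

# Minimal sketch, not part of this commit: one guided-decoding request to the
# "vllm_opt" model over Triton's streaming gRPC API.
import json
import queue
from functools import partial

import numpy as np
import tritonclient.grpc as grpcclient

results = queue.Queue()

def callback(result_queue, result, error):
    # Store either the successful result or the error for later inspection.
    result_queue.put(error if error is not None else result)

client = grpcclient.InferenceServerClient(url="localhost:8001")
client.start_stream(callback=partial(callback, results))

text_input = grpcclient.InferInput("text_input", [1], "BYTES")
text_input.set_data_from_numpy(
    np.array(
        [b"Classify intent of the sentence: Harry Potter is underrated. "],
        dtype=np.object_,
    )
)
stream_input = grpcclient.InferInput("stream", [1], "BOOL")
stream_input.set_data_from_numpy(np.array([False], dtype=bool))

# Sampling parameters ride in the request's `parameters` field;
# `guided_decoding` is a JSON-serialized dict of GuidedDecodingParams fields.
parameters = {
    "temperature": 0,
    "top_p": 1,
    "guided_decoding": json.dumps(
        {"choice": ["Positive", "Negative"], "backend": "outlines"}
    ),
}

client.async_stream_infer(
    model_name="vllm_opt",
    inputs=[text_input, stream_input],
    outputs=[grpcclient.InferRequestedOutput("text_output")],
    request_id="0",
    parameters=parameters,
)
print(results.get().as_numpy("text_output"))
client.stop_stream()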

ci/L0_backend_vllm/accuracy_test/test.sh

Lines changed: 6 additions & 2 deletions
@@ -1,5 +1,5 @@
 #!/bin/bash
-# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions
@@ -37,7 +37,7 @@ TEST_RESULT_FILE='test_results.txt'
 CLIENT_PY="./accuracy_test.py"
 SAMPLE_MODELS_REPO="../../../samples/model_repository"
 VLLM_ENGINE_LOG="vllm_engine.log"
-EXPECTED_NUM_TESTS=1
+EXPECTED_NUM_TESTS=2
 
 rm -rf models && mkdir -p models
 cp -r ${SAMPLE_MODELS_REPO}/vllm_model models/vllm_opt
@@ -50,6 +50,10 @@ set +e
 # memory issues: https://github.com/vllm-project/vllm/issues/2248
 python3 $CLIENT_PY --generate-baseline >> $VLLM_ENGINE_LOG 2>&1 & BASELINE_PID=$!
 wait $BASELINE_PID
+
+python3 $CLIENT_PY --generate-guided-baseline > $VLLM_ENGINE_LOG 2>&1 & BASELINE_PID=$!
+wait $BASELINE_PID
+
 set -e
 
 run_server
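
The added `--generate-guided-baseline` step runs before the server comes up and records what vLLM itself produces for GUIDED_PROMPTS, which test_guided_decoding later compares against. A rough standalone equivalent is sketched below; it assumes the facebook/opt-125m model used by the sample model repository and default engine arguments rather than the test's VLLM_ENGINE_CONFIG.

# Minimal sketch, assuming facebook/opt-125m: constrain vLLM's output to one of
# the listed choices via GuidedDecodingParams, mirroring the guided baseline run.
import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import GuidedDecodingParams
from vllm.utils import random_uuid


async def classify(prompt, engine):
    params = SamplingParams(temperature=0, top_p=1)
    # Attach the guided-decoding constraint, as generate_python_vllm_output() does.
    params.guided_decoding = GuidedDecodingParams(
        choice=["Positive", "Negative"], backend="outlines"
    )
    last = None
    async for output in engine.generate(prompt, params, random_uuid()):
        last = output
    return last.outputs[0].text


if __name__ == "__main__":
    engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(model="facebook/opt-125m"))
    prompt = "Classify intent of the sentence: Harry Potter is underrated. "
    print(asyncio.run(classify(prompt, engine)))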
