Commit 6287bf1

ci: Deprecate vllm "guided_decoding" with "structured_outputs" (#109)
1 parent d91c406 commit 6287bf1

File tree: 3 files changed, +27 -27 lines
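
Context for the diff below: vLLM has renamed its guided-decoding API to "structured outputs" (GuidedDecodingParams becomes StructuredOutputsParams, and the SamplingParams field guided_decoding becomes structured_outputs), so this commit renames the CI tests and backend utilities to match. A minimal before/after sketch of the rename, using only names that appear in the diff; the SamplingParams values mirror the test's defaults:

# Before: constraint attached via the old guided-decoding API.
# from vllm.sampling_params import GuidedDecodingParams
# sampling_params.guided_decoding = GuidedDecodingParams(choice=["Positive", "Negative"])

# After: the same constraint through the renamed structured-outputs API.
from vllm import SamplingParams
from vllm.sampling_params import StructuredOutputsParams

sampling_params = SamplingParams(temperature=0, top_p=1)
sampling_params.structured_outputs = StructuredOutputsParams(choice=["Positive", "Negative"])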

ci/L0_backend_vllm/accuracy_test/accuracy_test.py

Lines changed: 23 additions & 23 deletions
@@ -37,7 +37,7 @@
 from vllm import SamplingParams
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.sampling_params import GuidedDecodingParams
+from vllm.sampling_params import StructuredOutputsParams
 from vllm.utils import random_uuid

 sys.path.append("../../common")
@@ -55,7 +55,7 @@
     "The future of AI is",
 ]

-GUIDED_PROMPTS = ["Classify intent of the sentence: Harry Potter is underrated. "]
+STRUCTURED_PROMPTS = ["Classify intent of the sentence: Harry Potter is underrated. "]

 SAMPLING_PARAMETERS = {"temperature": 0, "top_p": 1}

@@ -64,13 +64,13 @@ async def generate_python_vllm_output(
     prompt,
     llm_engine,
     sampling_params=SamplingParams(**SAMPLING_PARAMETERS),
-    guided_generation=None,
+    structured_generation=None,
 ):
     request_id = random_uuid()
     python_vllm_output = None
     last_output = None
-    if guided_generation:
-        sampling_params.guided_decoding = guided_generation
+    if structured_generation:
+        sampling_params.structured_outputs = structured_generation

     async for vllm_output in llm_engine.generate(prompt, sampling_params, request_id):
         last_output = vllm_output
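
To see what these renamed parameters do end to end, here is a self-contained sketch in the spirit of the helper above: the constraint is attached to a SamplingParams object and the async engine is streamed to completion. The model name and engine setup are illustrative, not part of this commit:

import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import StructuredOutputsParams
from vllm.utils import random_uuid


async def classify(prompt: str) -> str:
    # Illustrative engine; the CI harness configures the actual model.
    engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(model="facebook/opt-125m"))
    sampling_params = SamplingParams(temperature=0, top_p=1)
    # Constrain generation to one of the two labels, as in the test above.
    sampling_params.structured_outputs = StructuredOutputsParams(choice=["Positive", "Negative"])
    last_output = None
    async for vllm_output in engine.generate(prompt, sampling_params, random_uuid()):
        last_output = vllm_output
    return last_output.outputs[0].text


print(asyncio.run(classify("Classify intent of the sentence: Harry Potter is underrated. ")))
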
@@ -83,7 +83,7 @@ async def generate_python_vllm_output(


 async def prepare_vllm_baseline_outputs(
-    export_file="vllm_baseline_output.pkl", prompts=PROMPTS, guided_generation=None
+    export_file="vllm_baseline_output.pkl", prompts=PROMPTS, structured_generation=None
 ):
     """
     Helper function that starts async vLLM engine and generates output for each
@@ -94,7 +94,7 @@ async def prepare_vllm_baseline_outputs(
     python_vllm_output = []
     for i in range(len(prompts)):
         output = await generate_python_vllm_output(
-            prompts[i], llm_engine, guided_generation=guided_generation
+            prompts[i], llm_engine, structured_generation=structured_generation
         )
         if output:
             python_vllm_output.extend(output)
@@ -160,10 +160,10 @@ def test_vllm_model(self):
         self.triton_client.stop_stream()
         self.assertEqual(self.python_vllm_output.sort(), triton_vllm_output.sort())

-    def test_guided_decoding(self):
+    def test_structured_outputs(self):
         # Reading and verifying baseline data
         self.python_vllm_output = []
-        with open("vllm_guided_baseline_output.pkl", "rb") as f:
+        with open("vllm_structured_baseline_output.pkl", "rb") as f:
             self.python_vllm_output = pickle.load(f)

         self.assertNotEqual(
@@ -176,9 +176,9 @@ def test_guided_decoding(self):
         )
         self.assertEqual(
             len(self.python_vllm_output),
-            len(GUIDED_PROMPTS),
+            len(STRUCTURED_PROMPTS),
             "Unexpected number of baseline outputs loaded, expected {}, but got {}".format(
-                len(GUIDED_PROMPTS), len(self.python_vllm_output)
+                len(STRUCTURED_PROMPTS), len(self.python_vllm_output)
             ),
         )

@@ -188,13 +188,13 @@ def test_guided_decoding(self):

         self.triton_client.start_stream(callback=partial(callback, user_data))
         sampling_params = SAMPLING_PARAMETERS
-        guided_decoding_params = {
+        structured_outputs_params = {
             "choice": ["Positive", "Negative"],
         }
-        sampling_params["guided_decoding"] = json.dumps(guided_decoding_params)
-        for i in range(len(GUIDED_PROMPTS)):
+        sampling_params["structured_outputs"] = json.dumps(structured_outputs_params)
+        for i in range(len(STRUCTURED_PROMPTS)):
             request_data = create_vllm_request(
-                GUIDED_PROMPTS[i], i, stream, sampling_params, self.vllm_model_name
+                STRUCTURED_PROMPTS[i], i, stream, sampling_params, self.vllm_model_name
             )
             self.triton_client.async_stream_infer(
                 model_name=self.vllm_model_name,
@@ -204,7 +204,7 @@ def test_guided_decoding(self):
                 parameters=request_data["parameters"],
             )

-        for i in range(len(GUIDED_PROMPTS)):
+        for i in range(len(STRUCTURED_PROMPTS)):
             result = user_data._completed_requests.get()
             self.assertIsNot(type(result), InferenceServerException, str(result))

@@ -230,7 +230,7 @@ def tearDown(self):
         help="Generates baseline output for accuracy tests",
     )
     parser.add_argument(
-        "--generate-guided-baseline",
+        "--generate-structured-baseline",
         action="store_true",
         required=False,
         default=False,
@@ -241,16 +241,16 @@ def tearDown(self):
         asyncio.run(prepare_vllm_baseline_outputs())
         exit(0)

-    if FLAGS.generate_guided_baseline:
-        guided_decoding_params = {
+    if FLAGS.generate_structured_baseline:
+        structured_outputs_params = {
             "choice": ["Positive", "Negative"],
         }
-        guided_generation = GuidedDecodingParams(**guided_decoding_params)
+        structured_generation = StructuredOutputsParams(**structured_outputs_params)
         asyncio.run(
             prepare_vllm_baseline_outputs(
-                export_file="vllm_guided_baseline_output.pkl",
-                prompts=GUIDED_PROMPTS,
-                guided_generation=guided_generation,
+                export_file="vllm_structured_baseline_output.pkl",
+                prompts=STRUCTURED_PROMPTS,
+                structured_generation=structured_generation,
             )
         )
         exit(0)
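
On the Triton client side, the constraint travels as a JSON string under the renamed "structured_outputs" key in the sampling parameters, which the backend deserializes (see src/utils/vllm_backend_utils.py below). A sketch of the parameters the test now builds; create_vllm_request is the test suite's own helper and is omitted here:

import json

sampling_parameters = {
    "temperature": 0,
    "top_p": 1,
    # Serialized constraint; formerly sent under the "guided_decoding" key.
    "structured_outputs": json.dumps({"choice": ["Positive", "Negative"]}),
}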

ci/L0_backend_vllm/accuracy_test/test.sh

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ set +e
 python3 $CLIENT_PY --generate-baseline >> $VLLM_ENGINE_LOG 2>&1 & BASELINE_PID=$!
 wait $BASELINE_PID

-python3 $CLIENT_PY --generate-guided-baseline > $VLLM_ENGINE_LOG 2>&1 & BASELINE_PID=$!
+python3 $CLIENT_PY --generate-structured-baseline > $VLLM_ENGINE_LOG 2>&1 & BASELINE_PID=$!
 wait $BASELINE_PID
 set -e

src/utils/vllm_backend_utils.py

Lines changed: 3 additions & 3 deletions
@@ -27,7 +27,7 @@
 import json
 from typing import Optional

-from vllm.sampling_params import GuidedDecodingParams, SamplingParams
+from vllm.sampling_params import SamplingParams, StructuredOutputsParams


 class TritonSamplingParams(SamplingParams):
@@ -84,8 +84,8 @@ def from_dict(
             Optional[int]: int,
         }
         for key, value in params_dict.items():
-            if key == "guided_decoding":
-                params_dict[key] = GuidedDecodingParams(**json.loads(value))
+            if key == "structured_outputs":
+                params_dict[key] = StructuredOutputsParams(**json.loads(value))
             elif key in vllm_params_dict:
                 vllm_type = vllm_params_dict[key]
                 if vllm_type in type_mapping:
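
Completing the round trip, the backend's from_dict now matches the renamed key and rebuilds the vLLM-native constraint object from the client's JSON payload. A minimal sketch of that parse step in isolation, with an illustrative dict literal:

import json

from vllm.sampling_params import StructuredOutputsParams

params_dict = {"structured_outputs": json.dumps({"choice": ["Positive", "Negative"]})}
for key, value in params_dict.items():
    if key == "structured_outputs":
        # Deserialize the client's JSON string back into the constraint object.
        params_dict[key] = StructuredOutputsParams(**json.loads(value))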
