-# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions
 
 import argparse
 import asyncio
+import json
 import pickle
 import sys
 import unittest
 from vllm import SamplingParams
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.sampling_params import GuidedDecodingParams
 from vllm.utils import random_uuid
 
 sys.path.append("../../common")
     "The future of AI is",
 ]
 
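+# Prompt used by the guided decoding accuracy test; generation for it is constrained
+# to a fixed set of choices (see test_guided_decoding below).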
+GUIDED_PROMPTS = ["Classify intent of the sentence: Harry Potter is underrated. "]
+
 SAMPLING_PARAMETERS = {"temperature": 0, "top_p": 1}
 
 
-async def generate_python_vllm_output(prompt, llm_engine):
+async def generate_python_vllm_output(
+    prompt,
+    llm_engine,
+    sampling_params=SamplingParams(**SAMPLING_PARAMETERS),
+    guided_generation=None,
+):
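+    # Note: the default SamplingParams object is built once, when the function is
+    # defined, and is reused by every call that does not pass its own sampling_params.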
     request_id = random_uuid()
-    sampling_params = SamplingParams(**SAMPLING_PARAMETERS)
     python_vllm_output = None
     last_output = None
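+    # If guided decoding constraints were provided, attach them to the sampling
+    # parameters so vLLM restricts generation accordingly.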
+    if guided_generation:
+        sampling_params.guided_decoding = guided_generation
 
     async for vllm_output in llm_engine.generate(prompt, sampling_params, request_id):
         last_output = vllm_output
@@ -69,24 +79,28 @@ async def generate_python_vllm_output(prompt, llm_engine):
     python_vllm_output = [
         (prompt + output.text).encode("utf-8") for output in last_output.outputs
     ]
-
     return python_vllm_output
 
 
-def prepare_vllm_baseline_outputs():
+def prepare_vllm_baseline_outputs(
+    export_file="vllm_baseline_output.pkl", prompts=PROMPTS, guided_generation=None
+):
7788 """
7889 Helper function that starts async vLLM engine and generates output for each
79- prompt in `PROMPTS `. Saves resulted baselines in `vllm_baseline_output.pkl`
90+ prompt in `prompts `. Saves resulted baselines in `vllm_baseline_output.pkl`
8091 for further use.
8192 """
     llm_engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**VLLM_ENGINE_CONFIG))
     python_vllm_output = []
-    for i in range(len(PROMPTS)):
+    for i in range(len(prompts)):
         python_vllm_output.extend(
-            asyncio.run(generate_python_vllm_output(PROMPTS[i], llm_engine))
+            asyncio.run(
+                generate_python_vllm_output(
+                    prompts[i], llm_engine, guided_generation=guided_generation
+                )
+            )
         )
-
-    with open("vllm_baseline_output.pkl", "wb") as f:
+    with open(export_file, "wb") as f:
         pickle.dump(python_vllm_output, f)
 
     return
@@ -96,6 +110,9 @@ class VLLMTritonAccuracyTest(TestResultCollector):
     def setUp(self):
         self.triton_client = grpcclient.InferenceServerClient(url="localhost:8001")
         self.vllm_model_name = "vllm_opt"
+
+    def test_vllm_model(self):
+        # Reading and verifying baseline data
         self.python_vllm_output = []
         with open("vllm_baseline_output.pkl", "rb") as f:
             self.python_vllm_output = pickle.load(f)
@@ -116,11 +133,9 @@ def setUp(self):
             ),
         )
 
-    def test_vllm_model(self):
         user_data = UserData()
         stream = False
         triton_vllm_output = []
-
         self.triton_client.start_stream(callback=partial(callback, user_data))
         for i in range(len(PROMPTS)):
             request_data = create_vllm_request(
@@ -131,7 +146,7 @@ def test_vllm_model(self):
                 request_id=request_data["request_id"],
                 inputs=request_data["inputs"],
                 outputs=request_data["outputs"],
-                parameters=SAMPLING_PARAMETERS,
+                parameters=request_data["parameters"],
             )
 
         for i in range(len(PROMPTS)):
@@ -146,6 +161,63 @@ def test_vllm_model(self):
         self.triton_client.stop_stream()
         self.assertEqual(self.python_vllm_output.sort(), triton_vllm_output.sort())
 
+    def test_guided_decoding(self):
+        # Reading and verifying baseline data
+        self.python_vllm_output = []
+        with open("vllm_guided_baseline_output.pkl", "rb") as f:
+            self.python_vllm_output = pickle.load(f)
+
+        self.assertNotEqual(
+            self.python_vllm_output,
+            [],
+            "Loaded baseline outputs' list should not be empty",
+        )
+        self.assertIsNotNone(
+            self.python_vllm_output, "Loaded baseline outputs' list should not be None"
+        )
+        self.assertEqual(
+            len(self.python_vllm_output),
+            len(GUIDED_PROMPTS),
+            "Unexpected number of baseline outputs loaded, expected {}, but got {}".format(
+                len(GUIDED_PROMPTS), len(self.python_vllm_output)
+            ),
+        )
+
+        user_data = UserData()
+        stream = False
+        triton_vllm_output = []
+
+        self.triton_client.start_stream(callback=partial(callback, user_data))
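+        # Guided decoding: restrict generation to one of the listed choices using the
+        # "outlines" backend. The constraint dictionary is JSON-serialized under the
+        # "guided_decoding" key of the sampling parameters passed to create_vllm_request.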
+        sampling_params = SAMPLING_PARAMETERS.copy()  # copy so the shared defaults stay unmodified
+        guided_decoding_params = {
+            "choice": ["Positive", "Negative"],
+            "backend": "outlines",
+        }
+        sampling_params["guided_decoding"] = json.dumps(guided_decoding_params)
+        for i in range(len(GUIDED_PROMPTS)):
+            request_data = create_vllm_request(
+                GUIDED_PROMPTS[i], i, stream, sampling_params, self.vllm_model_name
+            )
+            self.triton_client.async_stream_infer(
+                model_name=self.vllm_model_name,
+                request_id=request_data["request_id"],
+                inputs=request_data["inputs"],
+                outputs=request_data["outputs"],
+                parameters=request_data["parameters"],
+            )
+
+        for i in range(len(GUIDED_PROMPTS)):
+            result = user_data._completed_requests.get()
+            self.assertIsNot(type(result), InferenceServerException, str(result))
+
+            output = result.as_numpy("text_output")
+            self.assertIsNotNone(output, "`text_output` should not be None")
+
+            triton_vllm_output.extend(output)
+
+        self.triton_client.stop_stream()
+        self.assertEqual(sorted(self.python_vllm_output), sorted(triton_vllm_output))
+
     def tearDown(self):
         self.triton_client.close()
 
@@ -159,9 +231,29 @@ def tearDown(self):
         default=False,
         help="Generates baseline output for accuracy tests",
     )
+    parser.add_argument(
+        "--generate-guided-baseline",
+        action="store_true",
+        required=False,
+        default=False,
+        help="Generates guided decoding baseline output for accuracy tests",
+    )
     FLAGS = parser.parse_args()
     if FLAGS.generate_baseline:
         prepare_vllm_baseline_outputs()
         exit(0)
 
+    if FLAGS.generate_guided_baseline:
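+        # Generate the reference outputs with vLLM directly, using the same choice
+        # constraint as test_guided_decoding but expressed as vLLM's native
+        # GuidedDecodingParams.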
+        guided_decoding_params = {
+            "choice": ["Positive", "Negative"],
+            "backend": "outlines",
+        }
+        guided_generation = GuidedDecodingParams(**guided_decoding_params)
+        prepare_vllm_baseline_outputs(
+            export_file="vllm_guided_baseline_output.pkl",
+            prompts=GUIDED_PROMPTS,
+            guided_generation=guided_generation,
+        )
+        exit(0)
+
     unittest.main()