@@ -37,7 +37,7 @@
 from vllm import SamplingParams
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.sampling_params import GuidedDecodingParams
+from vllm.sampling_params import StructuredOutputsParams
 from vllm.utils import random_uuid

 sys.path.append("../../common")
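Note: vLLM has moved from the "guided decoding" naming to "structured outputs", and StructuredOutputsParams is the replacement for GuidedDecodingParams in vllm.sampling_params. A minimal sketch of the renamed API as this test exercises it, assuming a vLLM version that ships the new class:

    from vllm import SamplingParams
    from vllm.sampling_params import StructuredOutputsParams

    # Constrain decoding to a fixed label set, mirroring the test below.
    params = SamplingParams(temperature=0, top_p=1)
    params.structured_outputs = StructuredOutputsParams(choice=["Positive", "Negative"])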
5555 "The future of AI is" ,
5656]
5757
58- GUIDED_PROMPTS = ["Classify intent of the sentence: Harry Potter is underrated. " ]
58+ STRUCTURED_PROMPTS = ["Classify intent of the sentence: Harry Potter is underrated. " ]
5959
6060SAMPLING_PARAMETERS = {"temperature" : 0 , "top_p" : 1 }
6161
@@ -64,13 +64,13 @@ async def generate_python_vllm_output(
     prompt,
     llm_engine,
     sampling_params=SamplingParams(**SAMPLING_PARAMETERS),
-    guided_generation=None,
+    structured_generation=None,
 ):
     request_id = random_uuid()
     python_vllm_output = None
     last_output = None
-    if guided_generation:
-        sampling_params.guided_decoding = guided_generation
+    if structured_generation:
+        sampling_params.structured_outputs = structured_generation

     async for vllm_output in llm_engine.generate(prompt, sampling_params, request_id):
         last_output = vllm_output
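Only the choice constraint is exercised in this test. Assuming StructuredOutputsParams keeps the constraint fields of its predecessor GuidedDecodingParams (json, regex, choice, grammar), equivalent restrictions could be sketched as:

    from vllm.sampling_params import StructuredOutputsParams

    # Assumption: constraint field names carried over from GuidedDecodingParams.
    by_choice = StructuredOutputsParams(choice=["Positive", "Negative"])
    by_regex = StructuredOutputsParams(regex=r"(Positive|Negative)")
    by_json = StructuredOutputsParams(json={"enum": ["Positive", "Negative"]})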
@@ -83,7 +83,7 @@ async def generate_python_vllm_output(


 async def prepare_vllm_baseline_outputs(
-    export_file="vllm_baseline_output.pkl", prompts=PROMPTS, guided_generation=None
+    export_file="vllm_baseline_output.pkl", prompts=PROMPTS, structured_generation=None
 ):
     """
     Helper function that starts async vLLM engine and generates output for each
@@ -94,7 +94,7 @@ async def prepare_vllm_baseline_outputs(
     python_vllm_output = []
     for i in range(len(prompts)):
         output = await generate_python_vllm_output(
-            prompts[i], llm_engine, guided_generation=guided_generation
+            prompts[i], llm_engine, structured_generation=structured_generation
         )
         if output:
             python_vllm_output.extend(output)
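The pickle export itself falls outside this hunk; since test_structured_outputs reads the baseline back with pickle.load below, the write side of the helper presumably ends with something like this sketch:

    import pickle

    # Hypothetical tail of prepare_vllm_baseline_outputs: persist the collected
    # generations so the accuracy test can compare Triton's output against them.
    with open(export_file, "wb") as f:
        pickle.dump(python_vllm_output, f)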
@@ -160,10 +160,10 @@ def test_vllm_model(self):
         self.triton_client.stop_stream()
         self.assertEqual(self.python_vllm_output.sort(), triton_vllm_output.sort())

-    def test_guided_decoding(self):
+    def test_structured_outputs(self):
         # Reading and verifying baseline data
         self.python_vllm_output = []
-        with open("vllm_guided_baseline_output.pkl", "rb") as f:
+        with open("vllm_structured_baseline_output.pkl", "rb") as f:
             self.python_vllm_output = pickle.load(f)

         self.assertNotEqual(
@@ -176,9 +176,9 @@ def test_guided_decoding(self):
         )
         self.assertEqual(
             len(self.python_vllm_output),
-            len(GUIDED_PROMPTS),
+            len(STRUCTURED_PROMPTS),
             "Unexpected number of baseline outputs loaded, expected {}, but got {}".format(
-                len(GUIDED_PROMPTS), len(self.python_vllm_output)
+                len(STRUCTURED_PROMPTS), len(self.python_vllm_output)
             ),
         )

@@ -188,13 +188,13 @@ def test_guided_decoding(self):

         self.triton_client.start_stream(callback=partial(callback, user_data))
         sampling_params = SAMPLING_PARAMETERS
-        guided_decoding_params = {
+        structured_outputs_params = {
             "choice": ["Positive", "Negative"],
         }
-        sampling_params["guided_decoding"] = json.dumps(guided_decoding_params)
-        for i in range(len(GUIDED_PROMPTS)):
+        sampling_params["structured_outputs"] = json.dumps(structured_outputs_params)
+        for i in range(len(STRUCTURED_PROMPTS)):
             request_data = create_vllm_request(
-                GUIDED_PROMPTS[i], i, stream, sampling_params, self.vllm_model_name
+                STRUCTURED_PROMPTS[i], i, stream, sampling_params, self.vllm_model_name
             )
             self.triton_client.async_stream_infer(
                 model_name=self.vllm_model_name,
@@ -204,7 +204,7 @@ def test_guided_decoding(self):
                 parameters=request_data["parameters"],
             )

-        for i in range(len(GUIDED_PROMPTS)):
+        for i in range(len(STRUCTURED_PROMPTS)):
             result = user_data._completed_requests.get()
             self.assertIsNot(type(result), InferenceServerException, str(result))

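The callback handed to start_stream is defined in the shared utilities under ../../common and is not part of this diff; a typical Triton streaming callback of that shape simply enqueues each result or error for the assertions above, roughly:

    # Hedged sketch of the shared streaming callback (the actual helper lives
    # in ../../common): queue whatever arrives so the test thread can inspect it.
    def callback(user_data, result, error):
        if error:
            user_data._completed_requests.put(error)
        else:
            user_data._completed_requests.put(result)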
@@ -230,7 +230,7 @@ def tearDown(self):
         help="Generates baseline output for accuracy tests",
     )
     parser.add_argument(
-        "--generate-guided-baseline",
+        "--generate-structured-baseline",
         action="store_true",
         required=False,
         default=False,
@@ -241,16 +241,16 @@ def tearDown(self):
         asyncio.run(prepare_vllm_baseline_outputs())
         exit(0)

-    if FLAGS.generate_guided_baseline:
-        guided_decoding_params = {
+    if FLAGS.generate_structured_baseline:
+        structured_outputs_params = {
             "choice": ["Positive", "Negative"],
         }
-        guided_generation = GuidedDecodingParams(**guided_decoding_params)
+        structured_generation = StructuredOutputsParams(**structured_outputs_params)
         asyncio.run(
             prepare_vllm_baseline_outputs(
-                export_file="vllm_guided_baseline_output.pkl",
-                prompts=GUIDED_PROMPTS,
-                guided_generation=guided_generation,
+                export_file="vllm_structured_baseline_output.pkl",
+                prompts=STRUCTURED_PROMPTS,
+                structured_generation=structured_generation,
             )
         )
         exit(0)