-# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

+import argparse
 import asyncio
+import pickle
 import sys
 import unittest
 from functools import partial

 sys.path.append("../../common")
 from test_util import TestResultCollector, UserData, callback, create_vllm_request

+VLLM_ENGINE_CONFIG = {
+    "model": "facebook/opt-125m",
+    "gpu_memory_utilization": 0.3,
+}
+
+
+PROMPTS = [
+    "The most dangerous animal is",
+    "The capital of France is",
+    "The future of AI is",
+]
+
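+# Greedy decoding (temperature 0, top_p 1) keeps generation deterministic, so the
+# standalone vLLM baseline and the Triton-served outputs can be compared directly.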
+SAMPLING_PARAMETERS = {"temperature": 0, "top_p": 1}
+

 async def generate_python_vllm_output(prompt, llm_engine):
     request_id = random_uuid()
-    sampling_parameters = {"temperature": 0, "top_p": 1}
-    sampling_params = SamplingParams(**sampling_parameters)
-
+    sampling_params = SamplingParams(**SAMPLING_PARAMETERS)
     python_vllm_output = None
     last_output = None

@@ -59,50 +73,68 @@ async def generate_python_vllm_output(prompt, llm_engine):
     return python_vllm_output


+def prepare_vllm_baseline_outputs():
+    """
+    Helper function that starts an async vLLM engine and generates output for each
+    prompt in `PROMPTS`. Saves the resulting baselines in `vllm_baseline_output.pkl`
+    for later use.
+    """
+    llm_engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**VLLM_ENGINE_CONFIG))
+    python_vllm_output = []
+    for i in range(len(PROMPTS)):
+        python_vllm_output.extend(
+            asyncio.run(generate_python_vllm_output(PROMPTS[i], llm_engine))
+        )
+
+    with open("vllm_baseline_output.pkl", "wb") as f:
+        pickle.dump(python_vllm_output, f)
+
+    return
+
+
 class VLLMTritonAccuracyTest(TestResultCollector):
     def setUp(self):
         self.triton_client = grpcclient.InferenceServerClient(url="localhost:8001")
-        vllm_engine_config = {
-            "model": "facebook/opt-125m",
-            "gpu_memory_utilization": 0.3,
-        }
-
-        self.llm_engine = AsyncLLMEngine.from_engine_args(
-            AsyncEngineArgs(**vllm_engine_config)
-        )
         self.vllm_model_name = "vllm_opt"
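+        # The baseline file is produced ahead of time by running this script with
+        # --generate-baseline (see the __main__ block below).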
+        self.python_vllm_output = []
+        with open("vllm_baseline_output.pkl", "rb") as f:
+            self.python_vllm_output = pickle.load(f)
+
+        self.assertNotEqual(
+            self.python_vllm_output,
+            [],
+            "Loaded baseline outputs' list should not be empty",
+        )
+        self.assertIsNotNone(
+            self.python_vllm_output, "Loaded baseline outputs' list should not be None"
+        )
+        self.assertEqual(
+            len(self.python_vllm_output),
+            len(PROMPTS),
+            "Unexpected number of baseline outputs loaded, expected {}, but got {}".format(
+                len(PROMPTS), len(self.python_vllm_output)
+            ),
+        )

     def test_vllm_model(self):
         user_data = UserData()
         stream = False
-        prompts = [
-            "The most dangerous animal is",
-            "The capital of France is",
-            "The future of AI is",
-        ]
-        number_of_vllm_reqs = len(prompts)
-        sampling_parameters = {"temperature": "0", "top_p": "1"}
-        python_vllm_output = []
         triton_vllm_output = []

         self.triton_client.start_stream(callback=partial(callback, user_data))
-        for i in range(number_of_vllm_reqs):
+        for i in range(len(PROMPTS)):
             request_data = create_vllm_request(
-                prompts[i], i, stream, sampling_parameters, self.vllm_model_name
+                PROMPTS[i], i, stream, SAMPLING_PARAMETERS, self.vllm_model_name
             )
             self.triton_client.async_stream_infer(
                 model_name=self.vllm_model_name,
                 request_id=request_data["request_id"],
                 inputs=request_data["inputs"],
                 outputs=request_data["outputs"],
-                parameters=sampling_parameters,
-            )
-
-            python_vllm_output.extend(
-                asyncio.run(generate_python_vllm_output(prompts[i], self.llm_engine))
+                parameters=SAMPLING_PARAMETERS,
             )

-        for i in range(number_of_vllm_reqs):
+        for i in range(len(PROMPTS)):
             result = user_data._completed_requests.get()
             self.assertIsNot(type(result), InferenceServerException, str(result))

@@ -112,11 +144,24 @@ def test_vllm_model(self):
             triton_vllm_output.extend(output)

         self.triton_client.stop_stream()
-        self.assertEqual(python_vllm_output, triton_vllm_output)
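+        # Sort both lists so the comparison does not depend on response arrival order.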
+        self.assertEqual(sorted(self.python_vllm_output), sorted(triton_vllm_output))

     def tearDown(self):
         self.triton_client.close()


 if __name__ == "__main__":
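+    # With --generate-baseline the script only writes vllm_baseline_output.pkl and
+    # exits; otherwise the accuracy test suite runs.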
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--generate-baseline",
+        action="store_true",
+        required=False,
+        default=False,
+        help="Generates baseline output for accuracy tests",
+    )
+    FLAGS = parser.parse_args()
+    if FLAGS.generate_baseline:
+        prepare_vllm_baseline_outputs()
+        exit(0)
+
     unittest.main()