 sys.path.append("../../common")
 from test_util import TestResultCollector, UserData, callback, create_vllm_request
 
+PROMPTS = [
+    "The most dangerous animal is",
+    "The capital of France is",
+    "The future of AI is",
+]
+SAMPLING_PARAMETERS = {"temperature": "0", "top_p": "1"}
+
 
 class VLLMTritonBackendTest(TestResultCollector):
     def setUp(self):
@@ -60,8 +67,18 @@ def test_vllm_triton_backend(self):
         self.assertFalse(self.triton_client.is_model_ready(self.python_model_name))
 
         # Test vllm model and unload vllm model
-        self._test_vllm_model(send_parameters_as_tensor=True)
-        self._test_vllm_model(send_parameters_as_tensor=False)
+        self._test_vllm_model(
+            prompts=PROMPTS,
+            sampling_parameters=SAMPLING_PARAMETERS,
+            stream=False,
+            send_parameters_as_tensor=True,
+        )
+        self._test_vllm_model(
+            prompts=PROMPTS,
+            sampling_parameters=SAMPLING_PARAMETERS,
+            stream=False,
+            send_parameters_as_tensor=False,
+        )
         self.triton_client.unload_model(self.vllm_model_name)
 
     def test_model_with_invalid_attributes(self):
@@ -74,16 +91,90 @@ def test_vllm_invalid_model_name(self):
         with self.assertRaises(InferenceServerException):
             self.triton_client.load_model(model_name)
 
-    def _test_vllm_model(self, send_parameters_as_tensor):
-        user_data = UserData()
-        stream = False
+    def test_exclude_input_in_output_default(self):
+        """
+        Verify the default behavior of `exclude_input_in_output`
+        in non-streaming mode.
+        Expected result: the prompt is returned together with the generated text.
+        """
+        self.triton_client.load_model(self.vllm_model_name)
         prompts = [
-            "The most dangerous animal is",
             "The capital of France is",
-            "The future of AI is",
         ]
-        number_of_vllm_reqs = len(prompts)
+        expected_output = [
+            b"The capital of France is the capital of the French Republic.\n\nThe capital of France is the capital"
+        ]
+        sampling_parameters = {"temperature": "0", "top_p": "1"}
+        self._test_vllm_model(
+            prompts,
+            sampling_parameters,
+            stream=False,
+            send_parameters_as_tensor=True,
+            expected_output=expected_output,
+        )
+        self.triton_client.unload_model(self.vllm_model_name)
+
+    def test_exclude_input_in_output_false(self):
+        """
+        Verify the behavior of `exclude_input_in_output` = False
+        in non-streaming mode.
+        Expected result: the prompt is returned together with the generated text.
+        """
+        self.triton_client.load_model(self.vllm_model_name)
+        # Test vllm model and unload vllm model
+        prompts = [
+            "The capital of France is",
+        ]
+        expected_output = [
+            b"The capital of France is the capital of the French Republic.\n\nThe capital of France is the capital"
+        ]
+        sampling_parameters = {"temperature": "0", "top_p": "1"}
+        self._test_vllm_model(
+            prompts,
+            sampling_parameters,
+            stream=False,
+            send_parameters_as_tensor=True,
+            exclude_input_in_output=False,
+            expected_output=expected_output,
+        )
+        self.triton_client.unload_model(self.vllm_model_name)
+
+    def test_exclude_input_in_output_true(self):
+        """
+        Verify the behavior of `exclude_input_in_output` = True
+        in non-streaming mode.
+        Expected result: only the generated text is returned, without the prompt.
+        """
+        self.triton_client.load_model(self.vllm_model_name)
+        # Test vllm model and unload vllm model
+        prompts = [
+            "The capital of France is",
+        ]
+        expected_output = [
+            b" the capital of the French Republic.\n\nThe capital of France is the capital"
+        ]
         sampling_parameters = {"temperature": "0", "top_p": "1"}
+        self._test_vllm_model(
+            prompts,
+            sampling_parameters,
+            stream=False,
+            send_parameters_as_tensor=True,
+            exclude_input_in_output=True,
+            expected_output=expected_output,
+        )
+        self.triton_client.unload_model(self.vllm_model_name)
+
+    def _test_vllm_model(
+        self,
+        prompts,
+        sampling_parameters,
+        stream,
+        send_parameters_as_tensor,
+        exclude_input_in_output=None,
+        expected_output=None,
+    ):
+        user_data = UserData()
+        number_of_vllm_reqs = len(prompts)
 
         self.triton_client.start_stream(callback=partial(callback, user_data))
         for i in range(number_of_vllm_reqs):
@@ -94,6 +185,7 @@ def _test_vllm_model(self, send_parameters_as_tensor):
                 sampling_parameters,
                 self.vllm_model_name,
                 send_parameters_as_tensor,
+                exclude_input_in_output=exclude_input_in_output,
             )
             self.triton_client.async_stream_infer(
                 model_name=self.vllm_model_name,
@@ -111,6 +203,15 @@ def _test_vllm_model(self, send_parameters_as_tensor):
 
             output = result.as_numpy("text_output")
             self.assertIsNotNone(output, "`text_output` should not be None")
+            if expected_output is not None:
+                self.assertEqual(
+                    output,
+                    expected_output[i],
+                    'Actual and expected outputs do not match.\n'
+                    'Expected: "{}"\nActual: "{}"'.format(
+                        expected_output[i], output
+                    ),
+                )
 
         self.triton_client.stop_stream()
 
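For context, here is a minimal, hypothetical sketch of what a request builder like the `create_vllm_request` helper imported from `test_util` is expected to assemble. It assumes the vLLM backend model exposes input tensors named `text_input`, `stream`, `sampling_parameters`, and the new `exclude_input_in_output`, plus a `text_output` output; the function name `build_vllm_request` and the exact tensor names and datatypes are assumptions, not taken from this diff.

# Minimal sketch (assumptions noted above) of building a vLLM Triton request,
# including the optional exclude_input_in_output flag exercised by these tests.
import json

import numpy as np
import tritonclient.grpc as grpcclient


def build_vllm_request(
    prompt, request_id, stream, sampling_parameters, model_name, exclude_input_in_output=None
):
    # Prompt as a single BYTES element; tensor name "text_input" is an assumption.
    text_input = grpcclient.InferInput("text_input", [1], "BYTES")
    text_input.set_data_from_numpy(np.array([prompt.encode("utf-8")], dtype=np.object_))

    # Whether the backend should stream partial results back.
    stream_input = grpcclient.InferInput("stream", [1], "BOOL")
    stream_input.set_data_from_numpy(np.array([stream], dtype=bool))

    # Sampling parameters serialized as a JSON string.
    params_input = grpcclient.InferInput("sampling_parameters", [1], "BYTES")
    params_input.set_data_from_numpy(
        np.array([json.dumps(sampling_parameters).encode("utf-8")], dtype=np.object_)
    )

    inputs = [text_input, stream_input, params_input]

    # Only send the flag when a non-default value is requested, mirroring how the
    # tests pass exclude_input_in_output=None for the default case.
    if exclude_input_in_output is not None:
        exclude_input = grpcclient.InferInput("exclude_input_in_output", [1], "BOOL")
        exclude_input.set_data_from_numpy(np.array([exclude_input_in_output], dtype=bool))
        inputs.append(exclude_input)

    outputs = [grpcclient.InferRequestedOutput("text_output")]
    return {
        "model_name": model_name,
        "inputs": inputs,
        "outputs": outputs,
        "request_id": str(request_id),
    }

The returned keyword arguments would then be passed to `triton_client.async_stream_infer(...)`, matching how `_test_vllm_model` drives the model over the streaming gRPC API.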