Commit 5c03411

Add exclude_input_in_output option to vllm backend (#35)
1 parent 6f0afff commit 5c03411

7 files changed: +300 -38 lines changed

ci/L0_backend_vllm/enabled_stream/enabled_stream_test.py

Lines changed: 87 additions & 11 deletions

@@ -34,37 +34,113 @@
 sys.path.append("../../common")
 from test_util import AsyncTestResultCollector, create_vllm_request

+PROMPTS = ["The most dangerous animal is"]
+SAMPLING_PARAMETERS = {"temperature": "0", "top_p": "1"}
+

 class VLLMTritonStreamTest(AsyncTestResultCollector):
-    async def test_vllm_model_enabled_stream(self):
+    async def _test_vllm_model(
+        self,
+        prompts=PROMPTS,
+        sampling_parameters=SAMPLING_PARAMETERS,
+        stream=True,
+        exclude_input_in_output=None,
+        expected_output=None,
+        expect_error=False,
+    ):
         async with grpcclient.InferenceServerClient(
             url="localhost:8001"
         ) as triton_client:
             model_name = "vllm_opt"
-            stream = True
-            prompts = [
-                "The most dangerous animal is",
-                "The future of AI is",
-            ]
-            sampling_parameters = {"temperature": "0", "top_p": "1"}

             async def request_iterator():
                 for i, prompt in enumerate(prompts):
                     yield create_vllm_request(
-                        prompt, i, stream, sampling_parameters, model_name
+                        prompt,
+                        i,
+                        stream,
+                        sampling_parameters,
+                        model_name,
+                        exclude_input_in_output=exclude_input_in_output,
                     )

             response_iterator = triton_client.stream_infer(
                 inputs_iterator=request_iterator()
             )
-
+            final_response = []
             async for response in response_iterator:
                 result, error = response
-                self.assertIsNone(error, str(error))
-                self.assertIsNotNone(result, str(result))
+                if expect_error:
+                    self.assertIsInstance(error, InferenceServerException)
+                    self.assertEquals(
+                        error.message(),
+                        "Error generating stream: When streaming, `exclude_input_in_output` = False is not allowed.",
+                        error,
+                    )
+                    return

+                self.assertIsNone(error, error)
+                self.assertIsNotNone(result, result)
                 output = result.as_numpy("text_output")
                 self.assertIsNotNone(output, "`text_output` should not be None")
+                final_response.append(str(output[0], encoding="utf-8"))
+            if expected_output is not None:
+                self.assertEqual(
+                    final_response,
+                    expected_output,
+                    'Expected to receive the following response: "{}",\
+                    but received "{}".'.format(
+                        expected_output, final_response
+                    ),
+                )
+
+    async def test_vllm_model_enabled_stream(self):
+        """
+        Verifying that a request with multiple prompts runs successfully.
+        """
+        prompts = [
+            "The most dangerous animal is",
+            "The future of AI is",
+        ]
+
+        await self._test_vllm_model(prompts=prompts)
+
+    async def test_vllm_model_enabled_stream_exclude_input_in_output_default(self):
+        """
+        Verifying that a streaming request returns only generated diffs, which
+        is the default behaviour for `stream=True`.
+        """
+        expected_output = [
+            " the",
+            " one",
+            " that",
+            " is",
+            " most",
+            " likely",
+            " to",
+            " be",
+            " killed",
+            " by",
+            " a",
+            " car",
+            ".",
+            "\n",
+            "I",
+            "'m",
+        ]
+        await self._test_vllm_model(expected_output=expected_output)
+
+    async def test_vllm_model_enabled_stream_exclude_input_in_output_false(self):
+        """
+        Verifying that a streaming request which explicitly sets
+        `exclude_input_in_output` to False is rejected with an error,
+        since this combination is not allowed.
+        """
+        expected_output = "Error generating stream: When streaming, `exclude_input_in_output` = False is not allowed."
+        await self._test_vllm_model(
+            exclude_input_in_output=False,
+            expected_output=expected_output,
+            expect_error=True,
+        )


 if __name__ == "__main__":
ci/L0_backend_vllm/enabled_stream/test.sh

Lines changed: 1 addition & 1 deletion

@@ -36,7 +36,7 @@ CLIENT_LOG="./enabled_stream_client.log"
 TEST_RESULT_FILE='test_results.txt'
 CLIENT_PY="./enabled_stream_test.py"
 SAMPLE_MODELS_REPO="../../../samples/model_repository"
-EXPECTED_NUM_TESTS=1
+EXPECTED_NUM_TESTS=3

 rm -rf models && mkdir -p models
 cp -r ${SAMPLE_MODELS_REPO}/vllm_model models/vllm_opt

ci/L0_backend_vllm/vllm_backend/test.sh

Lines changed: 1 addition & 1 deletion

@@ -36,7 +36,7 @@ CLIENT_LOG="./vllm_backend_client.log"
 TEST_RESULT_FILE='test_results.txt'
 CLIENT_PY="./vllm_backend_test.py"
 SAMPLE_MODELS_REPO="../../../samples/model_repository"
-EXPECTED_NUM_TESTS=3
+EXPECTED_NUM_TESTS=6

 # Helpers =======================================
 function assert_curl_success {

ci/L0_backend_vllm/vllm_backend/vllm_backend_test.py

Lines changed: 109 additions & 8 deletions

@@ -35,6 +35,13 @@
 sys.path.append("../../common")
 from test_util import TestResultCollector, UserData, callback, create_vllm_request

+PROMPTS = [
+    "The most dangerous animal is",
+    "The capital of France is",
+    "The future of AI is",
+]
+SAMPLING_PARAMETERS = {"temperature": "0", "top_p": "1"}
+

 class VLLMTritonBackendTest(TestResultCollector):
     def setUp(self):
@@ -60,8 +67,18 @@ def test_vllm_triton_backend(self):
         self.assertFalse(self.triton_client.is_model_ready(self.python_model_name))

         # Test vllm model and unload vllm model
-        self._test_vllm_model(send_parameters_as_tensor=True)
-        self._test_vllm_model(send_parameters_as_tensor=False)
+        self._test_vllm_model(
+            prompts=PROMPTS,
+            sampling_parameters=SAMPLING_PARAMETERS,
+            stream=False,
+            send_parameters_as_tensor=True,
+        )
+        self._test_vllm_model(
+            prompts=PROMPTS,
+            sampling_parameters=SAMPLING_PARAMETERS,
+            stream=False,
+            send_parameters_as_tensor=False,
+        )
         self.triton_client.unload_model(self.vllm_model_name)

     def test_model_with_invalid_attributes(self):
@@ -74,16 +91,90 @@ def test_vllm_invalid_model_name(self):
         with self.assertRaises(InferenceServerException):
             self.triton_client.load_model(model_name)

-    def _test_vllm_model(self, send_parameters_as_tensor):
-        user_data = UserData()
-        stream = False
+    def test_exclude_input_in_output_default(self):
+        """
+        Verifying default behavior for `exclude_input_in_output`
+        in non-streaming mode.
+        Expected result: prompt is returned with diffs.
+        """
+        self.triton_client.load_model(self.vllm_model_name)
         prompts = [
-            "The most dangerous animal is",
             "The capital of France is",
-            "The future of AI is",
         ]
-        number_of_vllm_reqs = len(prompts)
+        expected_output = [
+            b"The capital of France is the capital of the French Republic.\n\nThe capital of France is the capital"
+        ]
+        sampling_parameters = {"temperature": "0", "top_p": "1"}
+        self._test_vllm_model(
+            prompts,
+            sampling_parameters,
+            stream=False,
+            send_parameters_as_tensor=True,
+            expected_output=expected_output,
+        )
+        self.triton_client.unload_model(self.vllm_model_name)
+
+    def test_exclude_input_in_output_false(self):
+        """
+        Verifying behavior for `exclude_input_in_output` = False
+        in non-streaming mode.
+        Expected result: prompt is returned with diffs.
+        """
+        self.triton_client.load_model(self.vllm_model_name)
+        # Test vllm model and unload vllm model
+        prompts = [
+            "The capital of France is",
+        ]
+        expected_output = [
+            b"The capital of France is the capital of the French Republic.\n\nThe capital of France is the capital"
+        ]
+        sampling_parameters = {"temperature": "0", "top_p": "1"}
+        self._test_vllm_model(
+            prompts,
+            sampling_parameters,
+            stream=False,
+            send_parameters_as_tensor=True,
+            exclude_input_in_output=False,
+            expected_output=expected_output,
+        )
+        self.triton_client.unload_model(self.vllm_model_name)
+
+    def test_exclude_input_in_output_true(self):
+        """
+        Verifying behavior for `exclude_input_in_output` = True
+        in non-streaming mode.
+        Expected result: only diffs are returned.
+        """
+        self.triton_client.load_model(self.vllm_model_name)
+        # Test vllm model and unload vllm model
+        prompts = [
+            "The capital of France is",
+        ]
+        expected_output = [
+            b" the capital of the French Republic.\n\nThe capital of France is the capital"
+        ]
+        sampling_parameters = {"temperature": "0", "top_p": "1"}
+        self._test_vllm_model(
+            prompts,
+            sampling_parameters,
+            stream=False,
+            send_parameters_as_tensor=True,
+            exclude_input_in_output=True,
+            expected_output=expected_output,
+        )
+        self.triton_client.unload_model(self.vllm_model_name)
+
+    def _test_vllm_model(
+        self,
+        prompts,
+        sampling_parameters,
+        stream,
+        send_parameters_as_tensor,
+        exclude_input_in_output=None,
+        expected_output=None,
+    ):
+        user_data = UserData()
+        number_of_vllm_reqs = len(prompts)

         self.triton_client.start_stream(callback=partial(callback, user_data))
         for i in range(number_of_vllm_reqs):
@@ -94,6 +185,7 @@ def _test_vllm_model(self, send_parameters_as_tensor):
                 sampling_parameters,
                 self.vllm_model_name,
                 send_parameters_as_tensor,
+                exclude_input_in_output=exclude_input_in_output,
             )
             self.triton_client.async_stream_infer(
                 model_name=self.vllm_model_name,
@@ -111,6 +203,15 @@ def _test_vllm_model(self, send_parameters_as_tensor):

             output = result.as_numpy("text_output")
             self.assertIsNotNone(output, "`text_output` should not be None")
+            if expected_output is not None:
+                self.assertEqual(
+                    output,
+                    expected_output[i],
+                    'Actual and expected outputs do not match.\n \
+                    Expected "{}" \n Actual:"{}"'.format(
+                        output, expected_output[i]
+                    ),
+                )

         self.triton_client.stop_stream()
ci/common/test_util.py

Lines changed: 5 additions & 0 deletions

@@ -95,6 +95,7 @@ def create_vllm_request(
     sampling_parameters,
     model_name,
     send_parameters_as_tensor=True,
+    exclude_input_in_output=None,
 ):
     inputs = []

@@ -111,6 +112,10 @@ def create_vllm_request(
         inputs.append(grpcclient.InferInput("sampling_parameters", [1], "BYTES"))
         inputs[-1].set_data_from_numpy(sampling_parameters_data)

+    if exclude_input_in_output is not None:
+        inputs.append(grpcclient.InferInput("exclude_input_in_output", [1], "BOOL"))
+        inputs[-1].set_data_from_numpy(np.array([exclude_input_in_output], dtype=bool))
+
     outputs = [grpcclient.InferRequestedOutput("text_output")]

     return {